From 4aa7dbb9c1ae15d336c090ff3124fdb01ff2e4e7 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 19 Dec 2025 17:25:06 +0800 Subject: [PATCH 001/107] ai draft Signed-off-by: Haoyang Li --- src/main/cpp/CMakeLists.txt | 2 + src/main/cpp/src/ProtobufSimpleJni.cpp | 64 +++ src/main/cpp/src/protobuf_simple.cu | 468 ++++++++++++++++++ src/main/cpp/src/protobuf_simple.hpp | 49 ++ .../spark/rapids/jni/ProtobufSimple.java | 64 +++ .../spark/rapids/jni/ProtobufSimpleTest.java | 94 ++++ 6 files changed, 741 insertions(+) create mode 100644 src/main/cpp/src/ProtobufSimpleJni.cpp create mode 100644 src/main/cpp/src/protobuf_simple.cu create mode 100644 src/main/cpp/src/protobuf_simple.hpp create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java create mode 100644 src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index dc6e37d1f8..94012c370c 100644 --- a/src/main/cpp/CMakeLists.txt +++ b/src/main/cpp/CMakeLists.txt @@ -208,6 +208,7 @@ add_library( src/NativeParquetJni.cpp src/NumberConverterJni.cpp src/ParseURIJni.cpp + src/ProtobufSimpleJni.cpp src/RegexRewriteUtilsJni.cpp src/RowConversionJni.cpp src/SparkResourceAdaptorJni.cpp @@ -250,6 +251,7 @@ add_library( src/murmur_hash.cu src/number_converter.cu src/parse_uri.cu + src/protobuf_simple.cu src/regex_rewrite_utils.cu src/row_conversion.cu src/round_float.cu diff --git a/src/main/cpp/src/ProtobufSimpleJni.cpp b/src/main/cpp/src/ProtobufSimpleJni.cpp new file mode 100644 index 0000000000..56f290d228 --- /dev/null +++ b/src/main/cpp/src/ProtobufSimpleJni.cpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cudf_jni_apis.hpp" +#include "dtype_utils.hpp" +#include "protobuf_simple.hpp" + +#include + +extern "C" { + +JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_ProtobufSimple_decodeToStruct( + JNIEnv* env, jclass, jlong binary_input_view, jintArray field_numbers, jintArray type_ids, jintArray type_scales) +{ + JNI_NULL_CHECK(env, binary_input_view, "binary_input_view is null", 0); + JNI_NULL_CHECK(env, field_numbers, "field_numbers is null", 0); + JNI_NULL_CHECK(env, type_ids, "type_ids is null", 0); + JNI_NULL_CHECK(env, type_scales, "type_scales is null", 0); + + JNI_TRY + { + cudf::jni::auto_set_device(env); + auto const* input = reinterpret_cast(binary_input_view); + cudf::jni::native_jintArray n_field_numbers(env, field_numbers); + cudf::jni::native_jintArray n_type_ids(env, type_ids); + cudf::jni::native_jintArray n_type_scales(env, type_scales); + if (n_field_numbers.size() != n_type_ids.size() || n_field_numbers.size() != n_type_scales.size()) { + JNI_THROW_NEW(env, + cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, + "fieldNumbers/typeIds/typeScales must be the same length", + 0); + } + + std::vector field_nums(n_field_numbers.begin(), n_field_numbers.end()); + std::vector out_types; + out_types.reserve(n_type_ids.size()); + for (int i = 0; i < n_type_ids.size(); ++i) { + out_types.emplace_back(cudf::jni::make_data_type(n_type_ids[i], n_type_scales[i])); + } + + auto result = + spark_rapids_jni::decode_protobuf_simple_to_struct(*input, field_nums, out_types); + return cudf::jni::release_as_jlong(result); + } + 
JNI_CATCH(env, 0); +} + +} // extern "C" + + + diff --git a/src/main/cpp/src/protobuf_simple.cu b/src/main/cpp/src/protobuf_simple.cu new file mode 100644 index 0000000000..d81e137a2f --- /dev/null +++ b/src/main/cpp/src/protobuf_simple.cu @@ -0,0 +1,468 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "protobuf_simple.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace { + +constexpr int WT_VARINT = 0; +constexpr int WT_64BIT = 1; +constexpr int WT_LEN = 2; +constexpr int WT_32BIT = 5; + +__device__ inline bool read_varint(uint8_t const* cur, uint8_t const* end, uint64_t& out, int& bytes) +{ + out = 0; + bytes = 0; + int shift = 0; + while (cur < end && bytes < 10) { + uint8_t b = *cur++; + out |= (static_cast(b & 0x7Fu) << shift); + bytes++; + if ((b & 0x80u) == 0) { return true; } + shift += 7; + } + return false; +} + +__device__ inline bool skip_field(uint8_t const* cur, uint8_t const* end, int wt, uint8_t const*& out_cur) +{ + out_cur = cur; + switch (wt) { + case WT_VARINT: { + uint64_t tmp; + int n; + if (!read_varint(out_cur, end, tmp, n)) return false; + out_cur += n; + return true; + } + case WT_64BIT: + if (end - out_cur < 8) return false; + out_cur += 8; + return true; + case WT_32BIT: + if (end - out_cur < 4) return false; + out_cur += 4; + return true; + 
case WT_LEN: { + uint64_t len64; + int n; + if (!read_varint(out_cur, end, len64, n)) return false; + out_cur += n; + if (len64 > static_cast(end - out_cur)) return false; + out_cur += static_cast(len64); + return true; + } + default: return false; + } +} + +template +__device__ inline T load_le(uint8_t const* p); + +template <> +__device__ inline uint32_t load_le(uint8_t const* p) +{ + return static_cast(p[0]) | (static_cast(p[1]) << 8) | + (static_cast(p[2]) << 16) | (static_cast(p[3]) << 24); +} + +template <> +__device__ inline uint64_t load_le(uint8_t const* p) +{ + uint64_t v = 0; + #pragma unroll + for (int i = 0; i < 8; ++i) { + v |= (static_cast(p[i]) << (8 * i)); + } + return v; +} + +template +__global__ void extract_varint_kernel(cudf::column_device_view const d_in, + int field_number, + OutT* out, + bool* valid, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + cudf::detail::lists_column_device_view in{d_in}; + if (row >= in.size()) return; + if (in.nullable() && in.is_null(row)) { + valid[row] = false; + return; + } + + // Use sliced child + offsets normalized to the slice start to correctly handle + // list columns with non-zero row offsets (and any child offsets). + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); + auto const* bytes = reinterpret_cast(child.data()); + auto start = in.offset_at(row) - base; + auto end = in.offset_at(row + 1) - base; + // Defensive bounds checks: if offsets are inconsistent, avoid illegal memory access. 
+ if (start < 0 || end < start || end > child.size()) { + *error_flag = 1; + valid[row] = false; + return; + } + uint8_t const* cur = bytes + start; + uint8_t const* stop = bytes + end; + + bool found = false; + OutT value{}; + while (cur < stop) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, stop, key, key_bytes)) { + *error_flag = 1; + break; + } + cur += key_bytes; + int fn = static_cast(key >> 3); + int wt = static_cast(key & 0x7); + if (fn == field_number) { + if (wt != WT_VARINT) { + *error_flag = 1; + break; + } + uint64_t v; + int n; + if (!read_varint(cur, stop, v, n)) { + *error_flag = 1; + break; + } + cur += n; + value = static_cast(v); + found = true; + // Continue scanning to allow "last one wins" semantics. + } else { + uint8_t const* next; + if (!skip_field(cur, stop, wt, next)) { + *error_flag = 1; + break; + } + cur = next; + } + } + + if (found) { + out[row] = value; + valid[row] = true; + } else { + valid[row] = false; + } +} + +template +__global__ void extract_fixed_kernel(cudf::column_device_view const d_in, + int field_number, + OutT* out, + bool* valid, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + cudf::detail::lists_column_device_view in{d_in}; + if (row >= in.size()) return; + if (in.nullable() && in.is_null(row)) { + valid[row] = false; + return; + } + + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); + auto const* bytes = reinterpret_cast(child.data()); + auto start = in.offset_at(row) - base; + auto end = in.offset_at(row + 1) - base; + if (start < 0 || end < start || end > child.size()) { + *error_flag = 1; + valid[row] = false; + return; + } + uint8_t const* cur = bytes + start; + uint8_t const* stop = bytes + end; + + bool found = false; + OutT value{}; + while (cur < stop) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, stop, key, key_bytes)) { + *error_flag = 1; + break; + } + cur += key_bytes; + int fn = static_cast(key >> 3); + int 
wt = static_cast(key & 0x7); + if (fn == field_number) { + if (wt != WT) { + *error_flag = 1; + break; + } + if constexpr (WT == WT_32BIT) { + if (stop - cur < 4) { *error_flag = 1; break; } + uint32_t raw = load_le(cur); + cur += 4; + value = *reinterpret_cast(&raw); + } else { + if (stop - cur < 8) { *error_flag = 1; break; } + uint64_t raw = load_le(cur); + cur += 8; + value = *reinterpret_cast(&raw); + } + found = true; + } else { + uint8_t const* next; + if (!skip_field(cur, stop, wt, next)) { + *error_flag = 1; + break; + } + cur = next; + } + } + + if (found) { + out[row] = value; + valid[row] = true; + } else { + valid[row] = false; + } +} + +__global__ void extract_string_kernel(cudf::column_device_view const d_in, + int field_number, + cudf::strings::detail::string_index_pair* out_pairs, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + cudf::detail::lists_column_device_view in{d_in}; + if (row >= in.size()) return; + if (in.nullable() && in.is_null(row)) { + out_pairs[row] = cudf::strings::detail::string_index_pair{nullptr, 0}; + return; + } + + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); + auto const* bytes = reinterpret_cast(child.data()); + auto start = in.offset_at(row) - base; + auto end = in.offset_at(row + 1) - base; + if (start < 0 || end < start || end > child.size()) { + *error_flag = 1; + out_pairs[row] = cudf::strings::detail::string_index_pair{nullptr, 0}; + return; + } + uint8_t const* cur = bytes + start; + uint8_t const* stop = bytes + end; + + cudf::strings::detail::string_index_pair pair{nullptr, 0}; + while (cur < stop) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, stop, key, key_bytes)) { + *error_flag = 1; + break; + } + cur += key_bytes; + int fn = static_cast(key >> 3); + int wt = static_cast(key & 0x7); + if (fn == field_number) { + if (wt != WT_LEN) { + *error_flag = 1; + break; + } + uint64_t len64; + int n; + if (!read_varint(cur, stop, 
len64, n)) { + *error_flag = 1; + break; + } + cur += n; + if (len64 > static_cast(stop - cur)) { + *error_flag = 1; + break; + } + pair.first = reinterpret_cast(cur); + pair.second = static_cast(len64); + cur += static_cast(len64); + // Continue scanning to allow "last one wins". + } else { + uint8_t const* next; + if (!skip_field(cur, stop, wt, next)) { + *error_flag = 1; + break; + } + cur = next; + } + } + + out_pairs[row] = pair; +} + +inline std::pair make_null_mask_from_valid( + rmm::device_uvector const& valid, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto begin = thrust::make_counting_iterator(0); + auto end = begin + valid.size(); + auto pred = [ptr = valid.data()] __device__(cudf::size_type i) { return ptr[i]; }; + return cudf::detail::valid_if(begin, end, pred, stream, mr); +} + +} // namespace + +namespace spark_rapids_jni { + +std::unique_ptr decode_protobuf_simple_to_struct( + cudf::column_view const& binary_input, + std::vector const& field_numbers, + std::vector const& out_types) +{ + CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, + "binary_input must be a LIST column"); + cudf::lists_column_view const in_list(binary_input); + auto const child_type = in_list.child().type().id(); + CUDF_EXPECTS(child_type == cudf::type_id::INT8 || child_type == cudf::type_id::UINT8, + "binary_input must be a LIST column"); + CUDF_EXPECTS(field_numbers.size() == out_types.size(), + "field_numbers and out_types must have the same length"); + + auto const stream = cudf::get_default_stream(); + auto mr = cudf::get_current_device_resource_ref(); + + auto d_in = cudf::column_device_view::create(binary_input, stream); + auto rows = binary_input.size(); + + // Track parse errors across kernels. 
+ rmm::device_uvector d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + + std::vector> children; + children.reserve(out_types.size()); + + auto const threads = 256; + auto const blocks = static_cast((rows + threads - 1) / threads); + + for (std::size_t i = 0; i < out_types.size(); ++i) { + auto const fn = field_numbers[i]; + auto const dt = out_types[i]; + switch (dt.id()) { + case cudf::type_id::BOOL8: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + extract_varint_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + children.push_back( + std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::INT32: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + extract_varint_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + children.push_back( + std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::INT64: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + extract_varint_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + children.push_back( + std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::FLOAT32: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + extract_fixed_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + children.push_back( + std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); + 
break; + } + case cudf::type_id::FLOAT64: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + extract_fixed_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + children.push_back( + std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::STRING: { + rmm::device_uvector pairs(rows, stream, mr); + extract_string_kernel<<>>(*d_in, fn, pairs.data(), d_error.data()); + children.push_back(cudf::strings::detail::make_strings_column( + pairs.begin(), pairs.end(), stream, mr)); + break; + } + default: CUDF_FAIL("Unsupported output type for protobuf_simple"); + } + } + + // Check for any parse errors. + int h_error = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + CUDF_EXPECTS(h_error == 0, "Malformed protobuf message or unsupported wire type"); + + // Note: We intentionally do NOT propagate input nulls to the output STRUCT validity. + // The expected semantics for this low-level helper (see ProtobufSimpleTest) are: + // - The STRUCT row is always valid (non-null) + // - Individual children are null if the input message is null or the field is missing + // + // Higher-level Spark expressions can still apply their own null semantics if desired. + rmm::device_buffer struct_mask{0, stream, mr}; + auto const struct_null_count = 0; + + return cudf::make_structs_column(rows, + std::move(children), + struct_null_count, + std::move(struct_mask), + stream, + mr); +} + +} // namespace spark_rapids_jni + + diff --git a/src/main/cpp/src/protobuf_simple.hpp b/src/main/cpp/src/protobuf_simple.hpp new file mode 100644 index 0000000000..ad086b2a32 --- /dev/null +++ b/src/main/cpp/src/protobuf_simple.hpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace spark_rapids_jni { + +/** + * Decode protobuf messages (one message per row) from a LIST column into a STRUCT column. + * + * This is intentionally limited to "simple types" (top-level scalar fields). + * + * Supported output child types: + * - BOOL8, INT32, INT64, FLOAT32, FLOAT64, STRING + * + * @param binary_input LIST column, each row is one protobuf message + * @param field_numbers protobuf field numbers (one per output child) + * @param out_types output cudf data types (one per output child) + * @return STRUCT column with the given children types, with nullability propagated from input rows + */ +std::unique_ptr decode_protobuf_simple_to_struct( + cudf::column_view const& binary_input, + std::vector const& field_numbers, + std::vector const& out_types); + +} // namespace spark_rapids_jni + + + diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java new file mode 100644 index 0000000000..c482659f51 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; +import ai.rapids.cudf.NativeDepsLoader; + +/** + * Simple GPU protobuf decoding utilities. + * + * This is intentionally limited to "simple types" (top-level scalar fields) as a first patch. + * Nested/repeated/map/oneof are out of scope for this API. + */ +public class ProtobufSimple { + static { + NativeDepsLoader.loadNativeDeps(); + } + + /** + * Decode a protobuf message-per-row binary column into a single STRUCT column. + * + * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. + * @param fieldNumbers protobuf field numbers to decode (one per struct child) + * @param typeIds cudf native type ids (one per struct child) + * @param typeScales cudf decimal scales (unused for simple types; pass 0s) + * @return a cudf STRUCT column where children correspond 1:1 with {@code fieldNumbers}/{@code typeIds}. 
+ */ + public static ColumnVector decodeToStruct(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] typeScales) { + if (fieldNumbers == null || typeIds == null || typeScales == null) { + throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be non-null"); + } + if (fieldNumbers.length != typeIds.length || fieldNumbers.length != typeScales.length) { + throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be the same length"); + } + long handle = decodeToStruct(binaryInput.getNativeView(), fieldNumbers, typeIds, typeScales); + return new ColumnVector(handle); + } + + private static native long decodeToStruct(long binaryInputView, + int[] fieldNumbers, + int[] typeIds, + int[] typeScales); +} + + + diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java new file mode 100644 index 0000000000..157f0e58db --- /dev/null +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java @@ -0,0 +1,94 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.AssertUtils; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.Table; +import org.junit.jupiter.api.Test; + +public class ProtobufSimpleTest { + + private static byte[] encodeVarint(long value) { + long v = value; + byte[] tmp = new byte[10]; + int idx = 0; + while ((v & ~0x7FL) != 0) { + tmp[idx++] = (byte) ((v & 0x7F) | 0x80); + v >>>= 7; + } + tmp[idx++] = (byte) (v & 0x7F); + byte[] out = new byte[idx]; + System.arraycopy(tmp, 0, out, 0, idx); + return out; + } + + private static Byte[] box(byte[] bytes) { + Byte[] out = new Byte[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + out[i] = bytes[i]; + } + return out; + } + + private static Byte[] concat(Byte[]... parts) { + int len = 0; + for (Byte[] p : parts) len += p.length; + Byte[] out = new Byte[len]; + int pos = 0; + for (Byte[] p : parts) { + System.arraycopy(p, 0, out, pos, p.length); + pos += p.length; + } + return out; + } + + @Test + void decodeVarintAndStringToStruct() { + // message Msg { int64 id = 1; string name = 2; } + // Row0: id=100, name="alice" + Byte[] row0 = concat( + new Byte[]{(byte) 0x08}, // field 1, varint + box(encodeVarint(100)), + new Byte[]{(byte) 0x12}, // field 2, len-delimited + box(encodeVarint(5)), + box("alice".getBytes())); + + // Row1: id=200, name missing + Byte[] row1 = concat( + new Byte[]{(byte) 0x08}, + box(encodeVarint(200))); + + // Row2: null input message + Byte[] row2 = null; + + try (Table input = new Table.TestBuilder().column(row0, row1, row2).build(); + ColumnVector expectedId = ColumnVector.fromBoxedLongs(100L, 200L, null); + ColumnVector expectedName = ColumnVector.fromStrings("alice", null, null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedId, expectedName); + ColumnVector actualStruct = ProtobufSimple.decodeToStruct( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), 
DType.STRING.getTypeId().getNativeId()}, + new int[]{0, 0})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } +} + + From a319d102e573e9e5e8bd5d49211b099fb1c4c2eb Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 23 Dec 2025 17:20:06 +0800 Subject: [PATCH 002/107] style Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufSimpleJni.cpp | 18 +++-- src/main/cpp/src/protobuf_simple.cu | 104 +++++++++++++------------ src/main/cpp/src/protobuf_simple.hpp | 6 +- 3 files changed, 65 insertions(+), 63 deletions(-) diff --git a/src/main/cpp/src/ProtobufSimpleJni.cpp b/src/main/cpp/src/ProtobufSimpleJni.cpp index 56f290d228..18d5c85326 100644 --- a/src/main/cpp/src/ProtobufSimpleJni.cpp +++ b/src/main/cpp/src/ProtobufSimpleJni.cpp @@ -22,8 +22,13 @@ extern "C" { -JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_ProtobufSimple_decodeToStruct( - JNIEnv* env, jclass, jlong binary_input_view, jintArray field_numbers, jintArray type_ids, jintArray type_scales) +JNIEXPORT jlong JNICALL +Java_com_nvidia_spark_rapids_jni_ProtobufSimple_decodeToStruct(JNIEnv* env, + jclass, + jlong binary_input_view, + jintArray field_numbers, + jintArray type_ids, + jintArray type_scales) { JNI_NULL_CHECK(env, binary_input_view, "binary_input_view is null", 0); JNI_NULL_CHECK(env, field_numbers, "field_numbers is null", 0); @@ -37,7 +42,8 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_ProtobufSimple_decodeTo cudf::jni::native_jintArray n_field_numbers(env, field_numbers); cudf::jni::native_jintArray n_type_ids(env, type_ids); cudf::jni::native_jintArray n_type_scales(env, type_scales); - if (n_field_numbers.size() != n_type_ids.size() || n_field_numbers.size() != n_type_scales.size()) { + if (n_field_numbers.size() != n_type_ids.size() || + n_field_numbers.size() != n_type_scales.size()) { JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, "fieldNumbers/typeIds/typeScales must be the same length", @@ -51,14 +57,10 @@ JNIEXPORT jlong 
JNICALL Java_com_nvidia_spark_rapids_jni_ProtobufSimple_decodeTo out_types.emplace_back(cudf::jni::make_data_type(n_type_ids[i], n_type_scales[i])); } - auto result = - spark_rapids_jni::decode_protobuf_simple_to_struct(*input, field_nums, out_types); + auto result = spark_rapids_jni::decode_protobuf_simple_to_struct(*input, field_nums, out_types); return cudf::jni::release_as_jlong(result); } JNI_CATCH(env, 0); } } // extern "C" - - - diff --git a/src/main/cpp/src/protobuf_simple.cu b/src/main/cpp/src/protobuf_simple.cu index d81e137a2f..5e1c801d85 100644 --- a/src/main/cpp/src/protobuf_simple.cu +++ b/src/main/cpp/src/protobuf_simple.cu @@ -16,12 +16,12 @@ #include "protobuf_simple.hpp" -#include #include +#include #include #include -#include #include +#include #include #include #include @@ -41,10 +41,13 @@ constexpr int WT_64BIT = 1; constexpr int WT_LEN = 2; constexpr int WT_32BIT = 5; -__device__ inline bool read_varint(uint8_t const* cur, uint8_t const* end, uint64_t& out, int& bytes) +__device__ inline bool read_varint(uint8_t const* cur, + uint8_t const* end, + uint64_t& out, + int& bytes) { - out = 0; - bytes = 0; + out = 0; + bytes = 0; int shift = 0; while (cur < end && bytes < 10) { uint8_t b = *cur++; @@ -56,7 +59,10 @@ __device__ inline bool read_varint(uint8_t const* cur, uint8_t const* end, uint6 return false; } -__device__ inline bool skip_field(uint8_t const* cur, uint8_t const* end, int wt, uint8_t const*& out_cur) +__device__ inline bool skip_field(uint8_t const* cur, + uint8_t const* end, + int wt, + uint8_t const*& out_cur) { out_cur = cur; switch (wt) { @@ -102,7 +108,7 @@ template <> __device__ inline uint64_t load_le(uint8_t const* p) { uint64_t v = 0; - #pragma unroll +#pragma unroll for (int i = 0; i < 8; ++i) { v |= (static_cast(p[i]) << (8 * i)); } @@ -110,11 +116,8 @@ __device__ inline uint64_t load_le(uint8_t const* p) } template -__global__ void extract_varint_kernel(cudf::column_device_view const d_in, - int field_number, - OutT* 
out, - bool* valid, - int* error_flag) +__global__ void extract_varint_kernel( + cudf::column_device_view const d_in, int field_number, OutT* out, bool* valid, int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; @@ -126,18 +129,18 @@ __global__ void extract_varint_kernel(cudf::column_device_view const d_in, // Use sliced child + offsets normalized to the slice start to correctly handle // list columns with non-zero row offsets (and any child offsets). - auto const base = in.offset_at(0); - auto const child = in.get_sliced_child(); + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); auto const* bytes = reinterpret_cast(child.data()); - auto start = in.offset_at(row) - base; - auto end = in.offset_at(row + 1) - base; + auto start = in.offset_at(row) - base; + auto end = in.offset_at(row + 1) - base; // Defensive bounds checks: if offsets are inconsistent, avoid illegal memory access. 
if (start < 0 || end < start || end > child.size()) { *error_flag = 1; - valid[row] = false; + valid[row] = false; return; } - uint8_t const* cur = bytes + start; + uint8_t const* cur = bytes + start; uint8_t const* stop = bytes + end; bool found = false; @@ -178,7 +181,7 @@ __global__ void extract_varint_kernel(cudf::column_device_view const d_in, } if (found) { - out[row] = value; + out[row] = value; valid[row] = true; } else { valid[row] = false; @@ -186,11 +189,8 @@ __global__ void extract_varint_kernel(cudf::column_device_view const d_in, } template -__global__ void extract_fixed_kernel(cudf::column_device_view const d_in, - int field_number, - OutT* out, - bool* valid, - int* error_flag) +__global__ void extract_fixed_kernel( + cudf::column_device_view const d_in, int field_number, OutT* out, bool* valid, int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; @@ -200,17 +200,17 @@ __global__ void extract_fixed_kernel(cudf::column_device_view const d_in, return; } - auto const base = in.offset_at(0); - auto const child = in.get_sliced_child(); + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); auto const* bytes = reinterpret_cast(child.data()); - auto start = in.offset_at(row) - base; - auto end = in.offset_at(row + 1) - base; + auto start = in.offset_at(row) - base; + auto end = in.offset_at(row + 1) - base; if (start < 0 || end < start || end > child.size()) { *error_flag = 1; - valid[row] = false; + valid[row] = false; return; } - uint8_t const* cur = bytes + start; + uint8_t const* cur = bytes + start; uint8_t const* stop = bytes + end; bool found = false; @@ -231,12 +231,18 @@ __global__ void extract_fixed_kernel(cudf::column_device_view const d_in, break; } if constexpr (WT == WT_32BIT) { - if (stop - cur < 4) { *error_flag = 1; break; } + if (stop - cur < 4) { + *error_flag = 1; + break; + } uint32_t raw = load_le(cur); cur += 4; value = 
*reinterpret_cast(&raw); } else { - if (stop - cur < 8) { *error_flag = 1; break; } + if (stop - cur < 8) { + *error_flag = 1; + break; + } uint64_t raw = load_le(cur); cur += 8; value = *reinterpret_cast(&raw); @@ -253,7 +259,7 @@ __global__ void extract_fixed_kernel(cudf::column_device_view const d_in, } if (found) { - out[row] = value; + out[row] = value; valid[row] = true; } else { valid[row] = false; @@ -273,17 +279,17 @@ __global__ void extract_string_kernel(cudf::column_device_view const d_in, return; } - auto const base = in.offset_at(0); - auto const child = in.get_sliced_child(); + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); auto const* bytes = reinterpret_cast(child.data()); - auto start = in.offset_at(row) - base; - auto end = in.offset_at(row + 1) - base; + auto start = in.offset_at(row) - base; + auto end = in.offset_at(row + 1) - base; if (start < 0 || end < start || end > child.size()) { - *error_flag = 1; + *error_flag = 1; out_pairs[row] = cudf::strings::detail::string_index_pair{nullptr, 0}; return; } - uint8_t const* cur = bytes + start; + uint8_t const* cur = bytes + start; uint8_t const* stop = bytes + end; cudf::strings::detail::string_index_pair pair{nullptr, 0}; @@ -431,9 +437,10 @@ std::unique_ptr decode_protobuf_simple_to_struct( } case cudf::type_id::STRING: { rmm::device_uvector pairs(rows, stream, mr); - extract_string_kernel<<>>(*d_in, fn, pairs.data(), d_error.data()); - children.push_back(cudf::strings::detail::make_strings_column( - pairs.begin(), pairs.end(), stream, mr)); + extract_string_kernel<<>>( + *d_in, fn, pairs.data(), d_error.data()); + children.push_back( + cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr)); break; } default: CUDF_FAIL("Unsupported output type for protobuf_simple"); @@ -442,7 +449,8 @@ std::unique_ptr decode_protobuf_simple_to_struct( // Check for any parse errors. 
int h_error = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY( + cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); CUDF_EXPECTS(h_error == 0, "Malformed protobuf message or unsupported wire type"); @@ -455,14 +463,8 @@ std::unique_ptr decode_protobuf_simple_to_struct( rmm::device_buffer struct_mask{0, stream, mr}; auto const struct_null_count = 0; - return cudf::make_structs_column(rows, - std::move(children), - struct_null_count, - std::move(struct_mask), - stream, - mr); + return cudf::make_structs_column( + rows, std::move(children), struct_null_count, std::move(struct_mask), stream, mr); } } // namespace spark_rapids_jni - - diff --git a/src/main/cpp/src/protobuf_simple.hpp b/src/main/cpp/src/protobuf_simple.hpp index ad086b2a32..1f8b57eb0d 100644 --- a/src/main/cpp/src/protobuf_simple.hpp +++ b/src/main/cpp/src/protobuf_simple.hpp @@ -26,7 +26,8 @@ namespace spark_rapids_jni { /** - * Decode protobuf messages (one message per row) from a LIST column into a STRUCT column. + * Decode protobuf messages (one message per row) from a LIST column into a STRUCT + * column. * * This is intentionally limited to "simple types" (top-level scalar fields). 
* @@ -44,6 +45,3 @@ std::unique_ptr decode_protobuf_simple_to_struct( std::vector const& out_types); } // namespace spark_rapids_jni - - - From 458b5835168cc3cd2fc79038e30295f30da595a3 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 25 Dec 2025 10:58:03 +0800 Subject: [PATCH 003/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufSimpleJni.cpp | 6 +- src/main/cpp/src/protobuf_simple.cu | 159 ++++++++++++++++-- src/main/cpp/src/protobuf_simple.hpp | 6 +- .../spark/rapids/jni/ProtobufSimple.java | 34 +++- .../spark/rapids/jni/ProtobufSimpleTest.java | 75 ++++++++- 5 files changed, 254 insertions(+), 26 deletions(-) diff --git a/src/main/cpp/src/ProtobufSimpleJni.cpp b/src/main/cpp/src/ProtobufSimpleJni.cpp index 18d5c85326..abf20e6533 100644 --- a/src/main/cpp/src/ProtobufSimpleJni.cpp +++ b/src/main/cpp/src/ProtobufSimpleJni.cpp @@ -28,7 +28,8 @@ Java_com_nvidia_spark_rapids_jni_ProtobufSimple_decodeToStruct(JNIEnv* env, jlong binary_input_view, jintArray field_numbers, jintArray type_ids, - jintArray type_scales) + jintArray type_scales, + jboolean fail_on_errors) { JNI_NULL_CHECK(env, binary_input_view, "binary_input_view is null", 0); JNI_NULL_CHECK(env, field_numbers, "field_numbers is null", 0); @@ -51,13 +52,14 @@ Java_com_nvidia_spark_rapids_jni_ProtobufSimple_decodeToStruct(JNIEnv* env, } std::vector field_nums(n_field_numbers.begin(), n_field_numbers.end()); + std::vector encodings(n_type_scales.begin(), n_type_scales.end()); std::vector out_types; out_types.reserve(n_type_ids.size()); for (int i = 0; i < n_type_ids.size(); ++i) { out_types.emplace_back(cudf::jni::make_data_type(n_type_ids[i], n_type_scales[i])); } - auto result = spark_rapids_jni::decode_protobuf_simple_to_struct(*input, field_nums, out_types); + auto result = spark_rapids_jni::decode_protobuf_simple_to_struct(*input, field_nums, out_types, encodings, fail_on_errors); return cudf::jni::release_as_jlong(result); } JNI_CATCH(env, 0); diff --git 
a/src/main/cpp/src/protobuf_simple.cu b/src/main/cpp/src/protobuf_simple.cu index 5e1c801d85..e6371614a7 100644 --- a/src/main/cpp/src/protobuf_simple.cu +++ b/src/main/cpp/src/protobuf_simple.cu @@ -41,6 +41,18 @@ constexpr int WT_64BIT = 1; constexpr int WT_LEN = 2; constexpr int WT_32BIT = 5; +} // namespace + +namespace spark_rapids_jni { + +constexpr int ENC_DEFAULT = 0; +constexpr int ENC_FIXED = 1; +constexpr int ENC_ZIGZAG = 2; + +} // namespace spark_rapids_jni + +namespace { + __device__ inline bool read_varint(uint8_t const* cur, uint8_t const* end, uint64_t& out, @@ -115,7 +127,7 @@ __device__ inline uint64_t load_le(uint8_t const* p) return v; } -template +template __global__ void extract_varint_kernel( cudf::column_device_view const d_in, int field_number, OutT* out, bool* valid, int* error_flag) { @@ -155,6 +167,10 @@ __global__ void extract_varint_kernel( cur += key_bytes; int fn = static_cast(key >> 3); int wt = static_cast(key & 0x7); + if (fn == 0) { + *error_flag = 1; + break; + } if (fn == field_number) { if (wt != WT_VARINT) { *error_flag = 1; @@ -167,6 +183,9 @@ __global__ void extract_varint_kernel( break; } cur += n; + if constexpr (ZigZag) { + v = (v >> 1) ^ (-(v & 1)); + } value = static_cast(v); found = true; // Continue scanning to allow "last one wins" semantics. 
@@ -225,6 +244,10 @@ __global__ void extract_fixed_kernel( cur += key_bytes; int fn = static_cast(key >> 3); int wt = static_cast(key & 0x7); + if (fn == 0) { + *error_flag = 1; + break; + } if (fn == field_number) { if (wt != WT) { *error_flag = 1; @@ -303,6 +326,10 @@ __global__ void extract_string_kernel(cudf::column_device_view const d_in, cur += key_bytes; int fn = static_cast(key >> 3); int wt = static_cast(key & 0x7); + if (fn == 0) { + *error_flag = 1; + break; + } if (fn == field_number) { if (wt != WT_LEN) { *error_flag = 1; @@ -354,7 +381,9 @@ namespace spark_rapids_jni { std::unique_ptr decode_protobuf_simple_to_struct( cudf::column_view const& binary_input, std::vector const& field_numbers, - std::vector const& out_types) + std::vector const& out_types, + std::vector const& encodings, + bool fail_on_errors) { CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, "binary_input must be a LIST column"); @@ -364,6 +393,8 @@ std::unique_ptr decode_protobuf_simple_to_struct( "binary_input must be a LIST column"); CUDF_EXPECTS(field_numbers.size() == out_types.size(), "field_numbers and out_types must have the same length"); + CUDF_EXPECTS(encodings.size() == out_types.size(), + "encodings and out_types must have the same length"); auto const stream = cudf::get_default_stream(); auto mr = cudf::get_current_device_resource_ref(); @@ -382,14 +413,19 @@ std::unique_ptr decode_protobuf_simple_to_struct( auto const blocks = static_cast((rows + threads - 1) / threads); for (std::size_t i = 0; i < out_types.size(); ++i) { - auto const fn = field_numbers[i]; - auto const dt = out_types[i]; + auto const fn = field_numbers[i]; + auto const dt = out_types[i]; + auto const enc = encodings[i]; switch (dt.id()) { case cudf::type_id::BOOL8: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); - extract_varint_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); + if (enc == ENC_DEFAULT) { + 
extract_varint_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else { + CUDF_FAIL("Unsupported encoding for BOOL8 protobuf field"); + } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); children.push_back( std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); @@ -398,8 +434,35 @@ std::unique_ptr decode_protobuf_simple_to_struct( case cudf::type_id::INT32: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); - extract_varint_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); + if (enc == ENC_ZIGZAG) { + extract_varint_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else if (enc == ENC_FIXED) { + extract_fixed_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else if (enc == ENC_DEFAULT) { + extract_varint_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else { + CUDF_FAIL("Unsupported encoding for INT32 protobuf field"); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + children.push_back( + std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::UINT32: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + if (enc == ENC_FIXED) { + extract_fixed_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else if (enc == ENC_DEFAULT) { + extract_varint_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else { + CUDF_FAIL("Unsupported encoding for UINT32 protobuf field"); + } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); children.push_back( std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); @@ -408,8 +471,35 @@ std::unique_ptr decode_protobuf_simple_to_struct( case cudf::type_id::INT64: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); - 
extract_varint_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); + if (enc == ENC_ZIGZAG) { + extract_varint_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else if (enc == ENC_FIXED) { + extract_fixed_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else if (enc == ENC_DEFAULT) { + extract_varint_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else { + CUDF_FAIL("Unsupported encoding for INT64 protobuf field"); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + children.push_back( + std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::UINT64: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + if (enc == ENC_FIXED) { + extract_fixed_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else if (enc == ENC_DEFAULT) { + extract_varint_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else { + CUDF_FAIL("Unsupported encoding for UINT64 protobuf field"); + } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); children.push_back( std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); @@ -418,8 +508,12 @@ std::unique_ptr decode_protobuf_simple_to_struct( case cudf::type_id::FLOAT32: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); - extract_fixed_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); + if (enc == ENC_DEFAULT) { + extract_fixed_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else { + CUDF_FAIL("Unsupported encoding for FLOAT32 protobuf field"); + } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); children.push_back( std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); @@ -428,8 +522,12 @@ std::unique_ptr decode_protobuf_simple_to_struct( case 
cudf::type_id::FLOAT64: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); - extract_fixed_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); + if (enc == ENC_DEFAULT) { + extract_fixed_kernel<<>>( + *d_in, fn, out.data(), valid.data(), d_error.data()); + } else { + CUDF_FAIL("Unsupported encoding for FLOAT64 protobuf field"); + } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); children.push_back( std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); @@ -437,12 +535,37 @@ std::unique_ptr decode_protobuf_simple_to_struct( } case cudf::type_id::STRING: { rmm::device_uvector pairs(rows, stream, mr); - extract_string_kernel<<>>( - *d_in, fn, pairs.data(), d_error.data()); + if (enc == ENC_DEFAULT) { + extract_string_kernel<<>>( + *d_in, fn, pairs.data(), d_error.data()); + } else { + CUDF_FAIL("Unsupported encoding for STRING protobuf field"); + } children.push_back( cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr)); break; } + case cudf::type_id::LIST: { + rmm::device_uvector pairs(rows, stream, mr); + if (enc == ENC_DEFAULT) { + extract_string_kernel<<>>( + *d_in, fn, pairs.data(), d_error.data()); + } else { + CUDF_FAIL("Unsupported encoding for LIST protobuf field"); + } + auto strings = cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr); + auto const null_count = strings->null_count(); + auto contents = strings->release(); + auto null_mask = contents.null_mask ? 
std::move(*contents.null_mask) : rmm::device_buffer{0, stream, mr}; + children.push_back(cudf::make_lists_column(rows, + std::move(contents.children[0]), + std::move(contents.children[1]), + null_count, + std::move(null_mask), + stream, + mr)); + break; + } default: CUDF_FAIL("Unsupported output type for protobuf_simple"); } } @@ -452,7 +575,9 @@ std::unique_ptr decode_protobuf_simple_to_struct( CUDF_CUDA_TRY( cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - CUDF_EXPECTS(h_error == 0, "Malformed protobuf message or unsupported wire type"); + if (fail_on_errors) { + CUDF_EXPECTS(h_error == 0, "Malformed protobuf message or unsupported wire type"); + } // Note: We intentionally do NOT propagate input nulls to the output STRUCT validity. // The expected semantics for this low-level helper (see ProtobufSimpleTest) are: diff --git a/src/main/cpp/src/protobuf_simple.hpp b/src/main/cpp/src/protobuf_simple.hpp index 1f8b57eb0d..c551d9adf8 100644 --- a/src/main/cpp/src/protobuf_simple.hpp +++ b/src/main/cpp/src/protobuf_simple.hpp @@ -37,11 +37,15 @@ namespace spark_rapids_jni { * @param binary_input LIST column, each row is one protobuf message * @param field_numbers protobuf field numbers (one per output child) * @param out_types output cudf data types (one per output child) + * @param encodings encoding type for each field (0=default, 1=fixed, 2=zigzag) + * @param fail_on_errors whether to throw on malformed messages * @return STRUCT column with the given children types, with nullability propagated from input rows */ std::unique_ptr decode_protobuf_simple_to_struct( cudf::column_view const& binary_input, std::vector const& field_numbers, - std::vector const& out_types); + std::vector const& out_types, + std::vector const& encodings, + bool fail_on_errors); } // namespace spark_rapids_jni diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java 
b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java index c482659f51..5a01b28cce 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java @@ -31,33 +31,61 @@ public class ProtobufSimple { NativeDepsLoader.loadNativeDeps(); } + public static final int ENC_DEFAULT = 0; + public static final int ENC_FIXED = 1; + public static final int ENC_ZIGZAG = 2; + /** * Decode a protobuf message-per-row binary column into a single STRUCT column. * * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. * @param fieldNumbers protobuf field numbers to decode (one per struct child) * @param typeIds cudf native type ids (one per struct child) - * @param typeScales cudf decimal scales (unused for simple types; pass 0s) + * @param typeScales encoding info or decimal scales: + * - For non-decimal types, this is the encoding: 0=default, 1=fixed, 2=zigzag. + * - For decimal types, this is the scale (currently unsupported). * @return a cudf STRUCT column where children correspond 1:1 with {@code fieldNumbers}/{@code typeIds}. */ public static ColumnVector decodeToStruct(ColumnView binaryInput, int[] fieldNumbers, int[] typeIds, int[] typeScales) { + return decodeToStruct(binaryInput, fieldNumbers, typeIds, typeScales, true); + } + + /** + * Decode a protobuf message-per-row binary column into a single STRUCT column. + * + * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. + * @param fieldNumbers protobuf field numbers to decode (one per struct child) + * @param typeIds cudf native type ids (one per struct child) + * @param typeScales encoding info or decimal scales: + * - For non-decimal types, this is the encoding: 0=default, 1=fixed, 2=zigzag. + * - For decimal types, this is the scale (currently unsupported). + * @param failOnErrors if true, throw an exception on malformed protobuf messages. 
+ * If false, return nulls for fields that cannot be parsed. + * @return a cudf STRUCT column where children correspond 1:1 with {@code fieldNumbers}/{@code typeIds}. + */ + public static ColumnVector decodeToStruct(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] typeScales, + boolean failOnErrors) { if (fieldNumbers == null || typeIds == null || typeScales == null) { throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be non-null"); } if (fieldNumbers.length != typeIds.length || fieldNumbers.length != typeScales.length) { throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be the same length"); } - long handle = decodeToStruct(binaryInput.getNativeView(), fieldNumbers, typeIds, typeScales); + long handle = decodeToStruct(binaryInput.getNativeView(), fieldNumbers, typeIds, typeScales, failOnErrors); return new ColumnVector(handle); } private static native long decodeToStruct(long binaryInputView, int[] fieldNumbers, int[] typeIds, - int[] typeScales); + int[] typeScales, + boolean failOnErrors); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java index 157f0e58db..71300f2fd5 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java @@ -19,9 +19,14 @@ import ai.rapids.cudf.AssertUtils; import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.DType; +import ai.rapids.cudf.HostColumnVector.*; import ai.rapids.cudf.Table; import org.junit.jupiter.api.Test; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; + public class ProtobufSimpleTest { private static byte[] encodeVarint(long value) { @@ -38,7 +43,20 @@ private static byte[] encodeVarint(long value) { return out; } + private static long zigzagEncode(long n) { + return (n << 1) ^ (n >> 63); + } + + private static byte[] 
encodeFixed32(int v) { + return ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(v).array(); + } + + private static byte[] encodeFixed64(long v) { + return ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(v).array(); + } + private static Byte[] box(byte[] bytes) { + if (bytes == null) return null; Byte[] out = new Byte[bytes.length]; for (int i = 0; i < bytes.length; i++) { out[i] = bytes[i]; @@ -48,12 +66,14 @@ private static Byte[] box(byte[] bytes) { private static Byte[] concat(Byte[]... parts) { int len = 0; - for (Byte[] p : parts) len += p.length; + for (Byte[] p : parts) if (p != null) len += p.length; Byte[] out = new Byte[len]; int pos = 0; for (Byte[] p : parts) { - System.arraycopy(p, 0, out, pos, p.length); - pos += p.length; + if (p != null) { + System.arraycopy(p, 0, out, pos, p.length); + pos += p.length; + } } return out; } @@ -89,6 +109,55 @@ void decodeVarintAndStringToStruct() { AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); } } + + @Test + void decodeMoreTypes() { + // message Msg { + // uint32 u32 = 1; + // sint64 s64 = 2; + // fixed32 f32 = 3; + // bytes b = 4; + // } + Byte[] row0 = concat( + new Byte[]{(byte) 0x08}, // field 1, varint + box(encodeVarint(4000000000L)), + new Byte[]{(byte) 0x10}, // field 2, varint + box(encodeVarint(zigzagEncode(-1234567890123L))), + new Byte[]{(byte) 0x1d}, // field 3, fixed32 + box(encodeFixed32(12345)), + new Byte[]{(byte) 0x22}, // field 4, len-delimited + box(encodeVarint(3)), + box(new byte[]{1, 2, 3})); + + try (Table input = new Table.TestBuilder().column(row0).build(); + ColumnVector expectedU32 = ColumnVector.fromBoxedLongs(4000000000L); // cuDF doesn't have boxed UInt32 easily, use Longs for test if needed, but we want native id + // Wait, I'll use direct values to avoid Boxing issues with UInt32 + ColumnVector expectedS64 = ColumnVector.fromBoxedLongs(-1234567890123L); + ColumnVector expectedF32 = ColumnVector.fromBoxedInts(12345); + ColumnVector 
expectedB = ColumnVector.fromLists( + new ListType(true, new BasicType(true, DType.INT8)), + Arrays.asList((byte) 1, (byte) 2, (byte) 3)); + ColumnVector actualStruct = ProtobufSimple.decodeToStruct( + input.getColumn(0), + new int[]{1, 2, 3, 4}, + new int[]{ + DType.UINT32.getTypeId().getNativeId(), + DType.INT64.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.LIST.getTypeId().getNativeId()}, + new int[]{ + ProtobufSimple.ENC_DEFAULT, + ProtobufSimple.ENC_ZIGZAG, + ProtobufSimple.ENC_FIXED, + ProtobufSimple.ENC_DEFAULT})) { + // For UINT32, expectedU32 from fromBoxedLongs will be INT64. + // I should use makeColumn to get exactly the right types for comparison. + try (ColumnVector expectedU32Correct = expectedU32.castTo(DType.UINT32); + ColumnVector expectedStructCorrect = ColumnVector.makeStruct(expectedU32Correct, expectedS64, expectedF32, expectedB)) { + AssertUtils.assertStructColumnsAreEqual(expectedStructCorrect, actualStruct); + } + } + } } From 5c1bbf43765821fc0ea45a8a45d1a97485c40267 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 25 Dec 2025 11:41:51 +0800 Subject: [PATCH 004/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufSimpleJni.cpp | 3 ++- src/main/cpp/src/protobuf_simple.cu | 10 +++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/main/cpp/src/ProtobufSimpleJni.cpp b/src/main/cpp/src/ProtobufSimpleJni.cpp index abf20e6533..6b63efa510 100644 --- a/src/main/cpp/src/ProtobufSimpleJni.cpp +++ b/src/main/cpp/src/ProtobufSimpleJni.cpp @@ -59,7 +59,8 @@ Java_com_nvidia_spark_rapids_jni_ProtobufSimple_decodeToStruct(JNIEnv* env, out_types.emplace_back(cudf::jni::make_data_type(n_type_ids[i], n_type_scales[i])); } - auto result = spark_rapids_jni::decode_protobuf_simple_to_struct(*input, field_nums, out_types, encodings, fail_on_errors); + auto result = spark_rapids_jni::decode_protobuf_simple_to_struct( + *input, field_nums, out_types, encodings, fail_on_errors); return 
cudf::jni::release_as_jlong(result); } JNI_CATCH(env, 0); diff --git a/src/main/cpp/src/protobuf_simple.cu b/src/main/cpp/src/protobuf_simple.cu index e6371614a7..06426f52ec 100644 --- a/src/main/cpp/src/protobuf_simple.cu +++ b/src/main/cpp/src/protobuf_simple.cu @@ -183,9 +183,7 @@ __global__ void extract_varint_kernel( break; } cur += n; - if constexpr (ZigZag) { - v = (v >> 1) ^ (-(v & 1)); - } + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } value = static_cast(v); found = true; // Continue scanning to allow "last one wins" semantics. @@ -553,10 +551,12 @@ std::unique_ptr decode_protobuf_simple_to_struct( } else { CUDF_FAIL("Unsupported encoding for LIST protobuf field"); } - auto strings = cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr); + auto strings = + cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr); auto const null_count = strings->null_count(); auto contents = strings->release(); - auto null_mask = contents.null_mask ? std::move(*contents.null_mask) : rmm::device_buffer{0, stream, mr}; + auto null_mask = + contents.null_mask ? 
std::move(*contents.null_mask) : rmm::device_buffer{0, stream, mr}; children.push_back(cudf::make_lists_column(rows, std::move(contents.children[0]), std::move(contents.children[1]), From 0445daacfd1a33d20d41f3d07fb91e492a908a08 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 25 Dec 2025 13:36:39 +0800 Subject: [PATCH 005/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufSimpleJni.cpp | 13 +++- src/main/cpp/src/protobuf_simple.cu | 62 +++++++++++-------- src/main/cpp/src/protobuf_simple.hpp | 19 +++++- .../spark/rapids/jni/ProtobufSimple.java | 18 ++++-- .../spark/rapids/jni/ProtobufSimpleTest.java | 2 - 5 files changed, 78 insertions(+), 36 deletions(-) diff --git a/src/main/cpp/src/ProtobufSimpleJni.cpp b/src/main/cpp/src/ProtobufSimpleJni.cpp index 6b63efa510..ab52841997 100644 --- a/src/main/cpp/src/ProtobufSimpleJni.cpp +++ b/src/main/cpp/src/ProtobufSimpleJni.cpp @@ -19,6 +19,7 @@ #include "protobuf_simple.hpp" #include +#include extern "C" { @@ -56,7 +57,17 @@ Java_com_nvidia_spark_rapids_jni_ProtobufSimple_decodeToStruct(JNIEnv* env, std::vector out_types; out_types.reserve(n_type_ids.size()); for (int i = 0; i < n_type_ids.size(); ++i) { - out_types.emplace_back(cudf::jni::make_data_type(n_type_ids[i], n_type_scales[i])); + // For protobuf simple decoding, typeScales contains encoding info (0=default, 1=fixed, + // 2=zigzag) not decimal scales. For non-decimal types, scale should be 0. Decimal types are + // not currently supported in protobuf simple decoder. 
+ auto type_id = static_cast(n_type_ids[i]); + if (cudf::is_fixed_point(cudf::data_type{type_id})) { + // For decimal types, use the scale from typeScales (though currently unsupported) + out_types.emplace_back(cudf::jni::make_data_type(n_type_ids[i], n_type_scales[i])); + } else { + // For non-decimal types, scale is always 0; typeScales contains encoding info + out_types.emplace_back(cudf::jni::make_data_type(n_type_ids[i], 0)); + } } auto result = spark_rapids_jni::decode_protobuf_simple_to_struct( diff --git a/src/main/cpp/src/protobuf_simple.cu b/src/main/cpp/src/protobuf_simple.cu index 06426f52ec..3222152c1c 100644 --- a/src/main/cpp/src/protobuf_simple.cu +++ b/src/main/cpp/src/protobuf_simple.cu @@ -98,8 +98,10 @@ __device__ inline bool skip_field(uint8_t const* cur, int n; if (!read_varint(out_cur, end, len64, n)) return false; out_cur += n; - if (len64 > static_cast(end - out_cur)) return false; - out_cur += static_cast(len64); + // Check for both buffer overflow and int overflow + if (len64 > static_cast(end - out_cur) || len64 > static_cast(INT_MAX)) + return false; + out_cur += static_cast(len64); return true; } default: return false; @@ -148,8 +150,8 @@ __global__ void extract_varint_kernel( auto end = in.offset_at(row + 1) - base; // Defensive bounds checks: if offsets are inconsistent, avoid illegal memory access. 
if (start < 0 || end < start || end > child.size()) { - *error_flag = 1; - valid[row] = false; + atomicExch(error_flag, 1); + valid[row] = false; return; } uint8_t const* cur = bytes + start; @@ -161,25 +163,25 @@ __global__ void extract_varint_kernel( uint64_t key; int key_bytes; if (!read_varint(cur, stop, key, key_bytes)) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } cur += key_bytes; int fn = static_cast(key >> 3); int wt = static_cast(key & 0x7); if (fn == 0) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } if (fn == field_number) { if (wt != WT_VARINT) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } uint64_t v; int n; if (!read_varint(cur, stop, v, n)) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } cur += n; @@ -190,7 +192,7 @@ __global__ void extract_varint_kernel( } else { uint8_t const* next; if (!skip_field(cur, stop, wt, next)) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } cur = next; @@ -223,8 +225,8 @@ __global__ void extract_fixed_kernel( auto start = in.offset_at(row) - base; auto end = in.offset_at(row + 1) - base; if (start < 0 || end < start || end > child.size()) { - *error_flag = 1; - valid[row] = false; + atomicExch(error_flag, 1); + valid[row] = false; return; } uint8_t const* cur = bytes + start; @@ -236,43 +238,45 @@ __global__ void extract_fixed_kernel( uint64_t key; int key_bytes; if (!read_varint(cur, stop, key, key_bytes)) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } cur += key_bytes; int fn = static_cast(key >> 3); int wt = static_cast(key & 0x7); if (fn == 0) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } if (fn == field_number) { if (wt != WT) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } if constexpr (WT == WT_32BIT) { if (stop - cur < 4) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } uint32_t raw = load_le(cur); cur += 4; - value = *reinterpret_cast(&raw); + // Use memcpy to avoid undefined behavior from type punning + 
memcpy(&value, &raw, sizeof(value)); } else { if (stop - cur < 8) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } uint64_t raw = load_le(cur); cur += 8; - value = *reinterpret_cast(&raw); + // Use memcpy to avoid undefined behavior from type punning + memcpy(&value, &raw, sizeof(value)); } found = true; } else { uint8_t const* next; if (!skip_field(cur, stop, wt, next)) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } cur = next; @@ -318,40 +322,41 @@ __global__ void extract_string_kernel(cudf::column_device_view const d_in, uint64_t key; int key_bytes; if (!read_varint(cur, stop, key, key_bytes)) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } cur += key_bytes; int fn = static_cast(key >> 3); int wt = static_cast(key & 0x7); if (fn == 0) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } if (fn == field_number) { if (wt != WT_LEN) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } uint64_t len64; int n; if (!read_varint(cur, stop, len64, n)) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } cur += n; - if (len64 > static_cast(stop - cur)) { - *error_flag = 1; + // Check for both buffer overflow and int overflow + if (len64 > static_cast(stop - cur) || len64 > static_cast(INT_MAX)) { + atomicExch(error_flag, 1); break; } pair.first = reinterpret_cast(cur); pair.second = static_cast(len64); - cur += static_cast(len64); + cur += static_cast(len64); // Continue scanning to allow "last one wins". } else { uint8_t const* next; if (!skip_field(cur, stop, wt, next)) { - *error_flag = 1; + atomicExch(error_flag, 1); break; } cur = next; @@ -570,6 +575,9 @@ std::unique_ptr decode_protobuf_simple_to_struct( } } + // Check for kernel launch errors + CUDF_CUDA_TRY(cudaPeekAtLastError()); + // Check for any parse errors. 
int h_error = 0; CUDF_CUDA_TRY( diff --git a/src/main/cpp/src/protobuf_simple.hpp b/src/main/cpp/src/protobuf_simple.hpp index c551d9adf8..14bdfe4352 100644 --- a/src/main/cpp/src/protobuf_simple.hpp +++ b/src/main/cpp/src/protobuf_simple.hpp @@ -31,8 +31,23 @@ namespace spark_rapids_jni { * * This is intentionally limited to "simple types" (top-level scalar fields). * - * Supported output child types: - * - BOOL8, INT32, INT64, FLOAT32, FLOAT64, STRING + * Supported output child types (cudf dtypes) and corresponding protobuf field types: + * - BOOL8 : protobuf `bool` (varint wire type) + * - INT32 : protobuf `int32`, `sint32` (with zigzag), `fixed32`/`sfixed32` (with fixed encoding) + * - UINT32 : protobuf `uint32`, `fixed32` (with fixed encoding) + * - INT64 : protobuf `int64`, `sint64` (with zigzag), `fixed64`/`sfixed64` (with fixed encoding) + * - UINT64 : protobuf `uint64`, `fixed64` (with fixed encoding) + * - FLOAT32 : protobuf `float` (fixed32 wire type) + * - FLOAT64 : protobuf `double` (fixed64 wire type) + * - STRING : protobuf `string` (length-delimited wire type, UTF-8 text) + * - LIST : protobuf `bytes` (length-delimited wire type, raw bytes as LIST) + * + * Integer handling: + * - For standard varint-encoded fields (`int32`, `int64`, `uint32`, `uint64`), use encoding=0. + * - For zigzag-encoded signed fields (`sint32`, `sint64`), use encoding=2. + * - For fixed-width fields (`fixed32`, `fixed64`, `sfixed32`, `sfixed64`), use encoding=1. + * + * Nested messages, repeated fields, map fields, and oneof fields are out of scope for this API. 
* * @param binary_input LIST column, each row is one protobuf message * @param field_numbers protobuf field numbers (one per output child) diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java index 5a01b28cce..410ee29c3e 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java @@ -23,8 +23,20 @@ /** * Simple GPU protobuf decoding utilities. * - * This is intentionally limited to "simple types" (top-level scalar fields) as a first patch. - * Nested/repeated/map/oneof are out of scope for this API. + * This API is intentionally limited to "simple types", i.e., top-level scalar fields whose + * values can be represented by a single cuDF scalar type. Supported protobuf field types + * include scalar fields using the standard protobuf wire encodings: + *
    + *
  • VARINT: {@code int32}, {@code int64}, {@code uint32}, {@code uint64}, {@code bool}
  • + *
  • ZIGZAG VARINT (encoding=2): {@code sint32}, {@code sint64}
  • + *
  • FIXED32 (encoding=1): {@code fixed32}, {@code sfixed32}, {@code float}
  • + *
  • FIXED64 (encoding=1): {@code fixed64}, {@code sfixed64}, {@code double}
  • + *
  • LENGTH_DELIMITED: {@code string}, {@code bytes}
  • + *
+ * Each decoded field becomes a child column of the resulting STRUCT, with its cuDF type + * specified via the corresponding {@code typeIds} entry. + *

+ * Nested messages, repeated fields, map fields, and oneof fields are out of scope for this API. */ public class ProtobufSimple { static { @@ -88,5 +100,3 @@ private static native long decodeToStruct(long binaryInputView, boolean failOnErrors); } - - diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java index 71300f2fd5..db3d0e19e9 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java @@ -159,5 +159,3 @@ void decodeMoreTypes() { } } } - - From 8ddbf9606830bf758a1f3e46b6ae6a1e372dd89c Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 25 Dec 2025 15:39:05 +0800 Subject: [PATCH 006/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_simple.cu | 5 ++++- .../nvidia/spark/rapids/jni/ProtobufSimple.java | 17 +++++++++++++++++ .../spark/rapids/jni/ProtobufSimpleTest.java | 7 +++---- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/main/cpp/src/protobuf_simple.cu b/src/main/cpp/src/protobuf_simple.cu index 3222152c1c..e987729e80 100644 --- a/src/main/cpp/src/protobuf_simple.cu +++ b/src/main/cpp/src/protobuf_simple.cu @@ -310,7 +310,7 @@ __global__ void extract_string_kernel(cudf::column_device_view const d_in, auto start = in.offset_at(row) - base; auto end = in.offset_at(row + 1) - base; if (start < 0 || end < start || end > child.size()) { - *error_flag = 1; + atomicExch(error_flag, 1); out_pairs[row] = cudf::strings::detail::string_index_pair{nullptr, 0}; return; } @@ -549,6 +549,9 @@ std::unique_ptr decode_protobuf_simple_to_struct( break; } case cudf::type_id::LIST: { + // For protobuf `bytes` fields: we reuse the string extraction kernel to get the + // length-delimited raw bytes. The resulting strings column is then re-interpreted as + // LIST by extracting its internal offsets and char data (which is just raw bytes). 
rmm::device_uvector pairs(rows, stream, mr); if (enc == ENC_DEFAULT) { extract_string_kernel<<>>( diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java index 410ee29c3e..ba1443120e 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java @@ -89,6 +89,23 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, if (fieldNumbers.length != typeIds.length || fieldNumbers.length != typeScales.length) { throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be the same length"); } + // Validate field numbers are positive (protobuf field numbers must be 1-536870911) + for (int i = 0; i < fieldNumbers.length; i++) { + if (fieldNumbers[i] <= 0) { + throw new IllegalArgumentException( + "Invalid field number at index " + i + ": " + fieldNumbers[i] + + " (field numbers must be positive)"); + } + } + // Validate encoding values are within valid range + for (int i = 0; i < typeScales.length; i++) { + int enc = typeScales[i]; + if (enc < ENC_DEFAULT || enc > ENC_ZIGZAG) { + throw new IllegalArgumentException( + "Invalid encoding value at index " + i + ": " + enc + + " (expected " + ENC_DEFAULT + ", " + ENC_FIXED + ", or " + ENC_ZIGZAG + ")"); + } + } long handle = decodeToStruct(binaryInput.getNativeView(), fieldNumbers, typeIds, typeScales, failOnErrors); return new ColumnVector(handle); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java index db3d0e19e9..6cf17e2203 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java @@ -130,8 +130,8 @@ void decodeMoreTypes() { box(new byte[]{1, 2, 3})); try (Table input = new Table.TestBuilder().column(row0).build(); - ColumnVector expectedU32 = 
ColumnVector.fromBoxedLongs(4000000000L); // cuDF doesn't have boxed UInt32 easily, use Longs for test if needed, but we want native id - // Wait, I'll use direct values to avoid Boxing issues with UInt32 + // Use fromBoxedLongs then cast to UINT32 since cuDF Java lacks direct UINT32 factory + ColumnVector expectedU32 = ColumnVector.fromBoxedLongs(4000000000L); ColumnVector expectedS64 = ColumnVector.fromBoxedLongs(-1234567890123L); ColumnVector expectedF32 = ColumnVector.fromBoxedInts(12345); ColumnVector expectedB = ColumnVector.fromLists( @@ -150,8 +150,7 @@ void decodeMoreTypes() { ProtobufSimple.ENC_ZIGZAG, ProtobufSimple.ENC_FIXED, ProtobufSimple.ENC_DEFAULT})) { - // For UINT32, expectedU32 from fromBoxedLongs will be INT64. - // I should use makeColumn to get exactly the right types for comparison. + // Cast expectedU32 from INT64 to UINT32 to match the actual output type try (ColumnVector expectedU32Correct = expectedU32.castTo(DType.UINT32); ColumnVector expectedStructCorrect = ColumnVector.makeStruct(expectedU32Correct, expectedS64, expectedF32, expectedB)) { AssertUtils.assertStructColumnsAreEqual(expectedStructCorrect, actualStruct); From c9eea596df0af7e4fc16391443c02eaa873222f6 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sun, 4 Jan 2026 16:11:26 +0800 Subject: [PATCH 007/107] copyrights Signed-off-by: Haoyang Li --- src/main/cpp/CMakeLists.txt | 2 +- src/main/cpp/src/ProtobufSimpleJni.cpp | 2 +- src/main/cpp/src/protobuf_simple.cu | 2 +- src/main/cpp/src/protobuf_simple.hpp | 2 +- src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java | 2 +- .../java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index 94012c370c..cf70111521 100644 --- a/src/main/cpp/CMakeLists.txt +++ b/src/main/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 
2022-2025, NVIDIA CORPORATION. +# Copyright (c) 2022-2026, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at diff --git a/src/main/cpp/src/ProtobufSimpleJni.cpp b/src/main/cpp/src/ProtobufSimpleJni.cpp index ab52841997..2e26be7206 100644 --- a/src/main/cpp/src/ProtobufSimpleJni.cpp +++ b/src/main/cpp/src/ProtobufSimpleJni.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/main/cpp/src/protobuf_simple.cu b/src/main/cpp/src/protobuf_simple.cu index e987729e80..6bebc3b97d 100644 --- a/src/main/cpp/src/protobuf_simple.cu +++ b/src/main/cpp/src/protobuf_simple.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/main/cpp/src/protobuf_simple.hpp b/src/main/cpp/src/protobuf_simple.hpp index 14bdfe4352..439d7c20ca 100644 --- a/src/main/cpp/src/protobuf_simple.hpp +++ b/src/main/cpp/src/protobuf_simple.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java index ba1443120e..540af26651 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. 
+ * Copyright (c) 2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java index 6cf17e2203..c142b21b87 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 617471dbc2803821b2527bf8960f93ac23eb58b2 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 9 Jan 2026 15:16:29 +0800 Subject: [PATCH 008/107] update, added more tests Signed-off-by: Haoyang Li --- src/main/cpp/CMakeLists.txt | 4 +- ...{ProtobufSimpleJni.cpp => ProtobufJni.cpp} | 10 +- .../src/{protobuf_simple.cu => protobuf.cu} | 49 +- .../src/{protobuf_simple.hpp => protobuf.hpp} | 7 +- .../{ProtobufSimple.java => Protobuf.java} | 6 +- .../spark/rapids/jni/ProtobufSimpleTest.java | 160 --- .../nvidia/spark/rapids/jni/ProtobufTest.java | 1039 +++++++++++++++++ 7 files changed, 1093 insertions(+), 182 deletions(-) rename src/main/cpp/src/{ProtobufSimpleJni.cpp => ProtobufJni.cpp} (90%) rename src/main/cpp/src/{protobuf_simple.cu => protobuf.cu} (92%) rename src/main/cpp/src/{protobuf_simple.hpp => protobuf.hpp} (89%) rename src/main/java/com/nvidia/spark/rapids/jni/{ProtobufSimple.java => Protobuf.java} (97%) delete mode 100644 src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java create mode 100644 src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index 8b8e9a359f..034f5c3601 100644 --- a/src/main/cpp/CMakeLists.txt +++ 
b/src/main/cpp/CMakeLists.txt @@ -207,7 +207,7 @@ add_library( src/NativeParquetJni.cpp src/NumberConverterJni.cpp src/ParseURIJni.cpp - src/ProtobufSimpleJni.cpp + src/ProtobufJni.cpp src/RegexRewriteUtilsJni.cpp src/RowConversionJni.cpp src/SparkResourceAdaptorJni.cpp @@ -255,7 +255,7 @@ add_library( src/multiply.cu src/number_converter.cu src/parse_uri.cu - src/protobuf_simple.cu + src/protobuf.cu src/regex_rewrite_utils.cu src/row_conversion.cu src/round_float.cu diff --git a/src/main/cpp/src/ProtobufSimpleJni.cpp b/src/main/cpp/src/ProtobufJni.cpp similarity index 90% rename from src/main/cpp/src/ProtobufSimpleJni.cpp rename to src/main/cpp/src/ProtobufJni.cpp index ab52841997..ecf70a5e5c 100644 --- a/src/main/cpp/src/ProtobufSimpleJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -16,7 +16,7 @@ #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" -#include "protobuf_simple.hpp" +#include "protobuf.hpp" #include #include @@ -24,7 +24,7 @@ extern "C" { JNIEXPORT jlong JNICALL -Java_com_nvidia_spark_rapids_jni_ProtobufSimple_decodeToStruct(JNIEnv* env, +Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, jclass, jlong binary_input_view, jintArray field_numbers, @@ -57,9 +57,9 @@ Java_com_nvidia_spark_rapids_jni_ProtobufSimple_decodeToStruct(JNIEnv* env, std::vector out_types; out_types.reserve(n_type_ids.size()); for (int i = 0; i < n_type_ids.size(); ++i) { - // For protobuf simple decoding, typeScales contains encoding info (0=default, 1=fixed, + // For protobuf decoding, typeScales contains encoding info (0=default, 1=fixed, // 2=zigzag) not decimal scales. For non-decimal types, scale should be 0. Decimal types are - // not currently supported in protobuf simple decoder. + // not currently supported in protobuf decoder. 
auto type_id = static_cast(n_type_ids[i]); if (cudf::is_fixed_point(cudf::data_type{type_id})) { // For decimal types, use the scale from typeScales (though currently unsupported) @@ -70,7 +70,7 @@ Java_com_nvidia_spark_rapids_jni_ProtobufSimple_decodeToStruct(JNIEnv* env, } } - auto result = spark_rapids_jni::decode_protobuf_simple_to_struct( + auto result = spark_rapids_jni::decode_protobuf_to_struct( *input, field_nums, out_types, encodings, fail_on_errors); return cudf::jni::release_as_jlong(result); } diff --git a/src/main/cpp/src/protobuf_simple.cu b/src/main/cpp/src/protobuf.cu similarity index 92% rename from src/main/cpp/src/protobuf_simple.cu rename to src/main/cpp/src/protobuf.cu index e987729e80..a16d78b757 100644 --- a/src/main/cpp/src/protobuf_simple.cu +++ b/src/main/cpp/src/protobuf.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "protobuf_simple.hpp" +#include "protobuf.hpp" #include #include @@ -63,6 +63,11 @@ __device__ inline bool read_varint(uint8_t const* cur, int shift = 0; while (cur < end && bytes < 10) { uint8_t b = *cur++; + // For the 10th byte (bytes == 9, shift == 63), only the lowest bit is valid + // since we can only fit 1 more bit into uint64_t + if (bytes == 9 && (b & 0xFE) != 0) { + return false; // Invalid: 10th byte has more than 1 significant bit + } out |= (static_cast(b & 0x7Fu) << shift); bytes++; if ((b & 0x80u) == 0) { return true; } @@ -381,7 +386,7 @@ inline std::pair make_null_mask_from_valid( namespace spark_rapids_jni { -std::unique_ptr decode_protobuf_simple_to_struct( +std::unique_ptr decode_protobuf_to_struct( cudf::column_view const& binary_input, std::vector const& field_numbers, std::vector const& out_types, @@ -561,20 +566,46 @@ std::unique_ptr decode_protobuf_simple_to_struct( } auto strings = cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr); + + // Use strings_column_view to get the underlying data + cudf::strings_column_view scv(*strings); auto const 
null_count = strings->null_count(); - auto contents = strings->release(); - auto null_mask = - contents.null_mask ? std::move(*contents.null_mask) : rmm::device_buffer{0, stream, mr}; + + // Get offsets - need to copy since we can't take ownership from a view + auto offsets_col = std::make_unique(scv.offsets(), stream, mr); + + // Get chars data as INT8 column + auto chars_data = scv.chars_begin(stream); + auto chars_size = static_cast(scv.chars_size(stream)); + rmm::device_uvector chars_vec(chars_size, stream, mr); + if (chars_size > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(chars_vec.data(), + chars_data, + chars_size, + cudaMemcpyDeviceToDevice, + stream.value())); + } + // Create INT8 column from chars data + auto child_col = std::make_unique( + cudf::data_type{cudf::type_id::INT8}, + chars_size, + chars_vec.release(), + rmm::device_buffer{}, // no null mask for chars + 0); // no nulls + + // Get null mask + auto null_mask = cudf::copy_bitmask(*strings, stream, mr); + children.push_back(cudf::make_lists_column(rows, - std::move(contents.children[0]), - std::move(contents.children[1]), + std::move(offsets_col), + std::move(child_col), null_count, std::move(null_mask), stream, mr)); break; } - default: CUDF_FAIL("Unsupported output type for protobuf_simple"); + default: CUDF_FAIL("Unsupported output type for protobuf decoder"); } } @@ -591,7 +622,7 @@ std::unique_ptr decode_protobuf_simple_to_struct( } // Note: We intentionally do NOT propagate input nulls to the output STRUCT validity. 
- // The expected semantics for this low-level helper (see ProtobufSimpleTest) are: + // The expected semantics for this low-level helper (see ProtobufTest) are: // - The STRUCT row is always valid (non-null) // - Individual children are null if the input message is null or the field is missing // diff --git a/src/main/cpp/src/protobuf_simple.hpp b/src/main/cpp/src/protobuf.hpp similarity index 89% rename from src/main/cpp/src/protobuf_simple.hpp rename to src/main/cpp/src/protobuf.hpp index 14bdfe4352..7932e4d97e 100644 --- a/src/main/cpp/src/protobuf_simple.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -29,7 +29,7 @@ namespace spark_rapids_jni { * Decode protobuf messages (one message per row) from a LIST column into a STRUCT * column. * - * This is intentionally limited to "simple types" (top-level scalar fields). + * This is intentionally limited to top-level scalar fields. * * Supported output child types (cudf dtypes) and corresponding protobuf field types: * - BOOL8 : protobuf `bool` (varint wire type) @@ -54,9 +54,10 @@ namespace spark_rapids_jni { * @param out_types output cudf data types (one per output child) * @param encodings encoding type for each field (0=default, 1=fixed, 2=zigzag) * @param fail_on_errors whether to throw on malformed messages - * @return STRUCT column with the given children types, with nullability propagated from input rows + * @return STRUCT column with the given children types; the STRUCT itself is always non-null, + * and individual child fields may be null when input message is null or field is missing */ -std::unique_ptr decode_protobuf_simple_to_struct( +std::unique_ptr decode_protobuf_to_struct( cudf::column_view const& binary_input, std::vector const& field_numbers, std::vector const& out_types, diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java similarity index 97% rename from src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java rename 
to src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index ba1443120e..13a03e9808 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -21,9 +21,9 @@ import ai.rapids.cudf.NativeDepsLoader; /** - * Simple GPU protobuf decoding utilities. + * GPU protobuf decoding utilities. * - * This API is intentionally limited to "simple types", i.e., top-level scalar fields whose + * This API is intentionally limited to top-level scalar fields whose * values can be represented by a single cuDF scalar type. Supported protobuf field types * include scalar fields using the standard protobuf wire encodings: *

    @@ -38,7 +38,7 @@ *

    * Nested messages, repeated fields, map fields, and oneof fields are out of scope for this API. */ -public class ProtobufSimple { +public class Protobuf { static { NativeDepsLoader.loadNativeDeps(); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java deleted file mode 100644 index 6cf17e2203..0000000000 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Copyright (c) 2025, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package com.nvidia.spark.rapids.jni; - -import ai.rapids.cudf.AssertUtils; -import ai.rapids.cudf.ColumnVector; -import ai.rapids.cudf.DType; -import ai.rapids.cudf.HostColumnVector.*; -import ai.rapids.cudf.Table; -import org.junit.jupiter.api.Test; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.util.Arrays; - -public class ProtobufSimpleTest { - - private static byte[] encodeVarint(long value) { - long v = value; - byte[] tmp = new byte[10]; - int idx = 0; - while ((v & ~0x7FL) != 0) { - tmp[idx++] = (byte) ((v & 0x7F) | 0x80); - v >>>= 7; - } - tmp[idx++] = (byte) (v & 0x7F); - byte[] out = new byte[idx]; - System.arraycopy(tmp, 0, out, 0, idx); - return out; - } - - private static long zigzagEncode(long n) { - return (n << 1) ^ (n >> 63); - } - - private static byte[] encodeFixed32(int v) { - return ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(v).array(); - } - - private static byte[] encodeFixed64(long v) { - return ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(v).array(); - } - - private static Byte[] box(byte[] bytes) { - if (bytes == null) return null; - Byte[] out = new Byte[bytes.length]; - for (int i = 0; i < bytes.length; i++) { - out[i] = bytes[i]; - } - return out; - } - - private static Byte[] concat(Byte[]... 
parts) { - int len = 0; - for (Byte[] p : parts) if (p != null) len += p.length; - Byte[] out = new Byte[len]; - int pos = 0; - for (Byte[] p : parts) { - if (p != null) { - System.arraycopy(p, 0, out, pos, p.length); - pos += p.length; - } - } - return out; - } - - @Test - void decodeVarintAndStringToStruct() { - // message Msg { int64 id = 1; string name = 2; } - // Row0: id=100, name="alice" - Byte[] row0 = concat( - new Byte[]{(byte) 0x08}, // field 1, varint - box(encodeVarint(100)), - new Byte[]{(byte) 0x12}, // field 2, len-delimited - box(encodeVarint(5)), - box("alice".getBytes())); - - // Row1: id=200, name missing - Byte[] row1 = concat( - new Byte[]{(byte) 0x08}, - box(encodeVarint(200))); - - // Row2: null input message - Byte[] row2 = null; - - try (Table input = new Table.TestBuilder().column(row0, row1, row2).build(); - ColumnVector expectedId = ColumnVector.fromBoxedLongs(100L, 200L, null); - ColumnVector expectedName = ColumnVector.fromStrings("alice", null, null); - ColumnVector expectedStruct = ColumnVector.makeStruct(expectedId, expectedName); - ColumnVector actualStruct = ProtobufSimple.decodeToStruct( - input.getColumn(0), - new int[]{1, 2}, - new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, - new int[]{0, 0})) { - AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); - } - } - - @Test - void decodeMoreTypes() { - // message Msg { - // uint32 u32 = 1; - // sint64 s64 = 2; - // fixed32 f32 = 3; - // bytes b = 4; - // } - Byte[] row0 = concat( - new Byte[]{(byte) 0x08}, // field 1, varint - box(encodeVarint(4000000000L)), - new Byte[]{(byte) 0x10}, // field 2, varint - box(encodeVarint(zigzagEncode(-1234567890123L))), - new Byte[]{(byte) 0x1d}, // field 3, fixed32 - box(encodeFixed32(12345)), - new Byte[]{(byte) 0x22}, // field 4, len-delimited - box(encodeVarint(3)), - box(new byte[]{1, 2, 3})); - - try (Table input = new Table.TestBuilder().column(row0).build(); - // Use fromBoxedLongs 
then cast to UINT32 since cuDF Java lacks direct UINT32 factory - ColumnVector expectedU32 = ColumnVector.fromBoxedLongs(4000000000L); - ColumnVector expectedS64 = ColumnVector.fromBoxedLongs(-1234567890123L); - ColumnVector expectedF32 = ColumnVector.fromBoxedInts(12345); - ColumnVector expectedB = ColumnVector.fromLists( - new ListType(true, new BasicType(true, DType.INT8)), - Arrays.asList((byte) 1, (byte) 2, (byte) 3)); - ColumnVector actualStruct = ProtobufSimple.decodeToStruct( - input.getColumn(0), - new int[]{1, 2, 3, 4}, - new int[]{ - DType.UINT32.getTypeId().getNativeId(), - DType.INT64.getTypeId().getNativeId(), - DType.INT32.getTypeId().getNativeId(), - DType.LIST.getTypeId().getNativeId()}, - new int[]{ - ProtobufSimple.ENC_DEFAULT, - ProtobufSimple.ENC_ZIGZAG, - ProtobufSimple.ENC_FIXED, - ProtobufSimple.ENC_DEFAULT})) { - // Cast expectedU32 from INT64 to UINT32 to match the actual output type - try (ColumnVector expectedU32Correct = expectedU32.castTo(DType.UINT32); - ColumnVector expectedStructCorrect = ColumnVector.makeStruct(expectedU32Correct, expectedS64, expectedF32, expectedB)) { - AssertUtils.assertStructColumnsAreEqual(expectedStructCorrect, actualStruct); - } - } - } -} diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java new file mode 100644 index 0000000000..c2d1be0a35 --- /dev/null +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -0,0 +1,1039 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import ai.rapids.cudf.AssertUtils; +import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.DType; +import ai.rapids.cudf.HostColumnVector.*; +import ai.rapids.cudf.Table; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; + +/** + * Tests for the Protobuf GPU decoder. + * + * Test cases are inspired by Google's protobuf conformance test suite: + * https://github.com/protocolbuffers/protobuf/tree/main/conformance + */ +public class ProtobufTest { + + // ============================================================================ + // Helper methods for encoding protobuf wire format + // ============================================================================ + + /** Encode a value as a varint (variable-length integer). */ + private static byte[] encodeVarint(long value) { + long v = value; + byte[] tmp = new byte[10]; + int idx = 0; + while ((v & ~0x7FL) != 0) { + tmp[idx++] = (byte) ((v & 0x7F) | 0x80); + v >>>= 7; + } + tmp[idx++] = (byte) (v & 0x7F); + byte[] out = new byte[idx]; + System.arraycopy(tmp, 0, out, 0, idx); + return out; + } + + /** + * Encode a varint with extra padding bytes (over-encoded but valid). + * This is useful for testing that parsers accept non-canonical varints. 
+ */ + private static byte[] encodeLongVarint(long value, int extraBytes) { + byte[] tmp = new byte[10]; + int idx = 0; + long v = value; + while ((v & ~0x7FL) != 0 || extraBytes > 0) { + tmp[idx++] = (byte) ((v & 0x7F) | 0x80); + v >>>= 7; + if (v == 0 && extraBytes > 0) { + extraBytes--; + } + } + tmp[idx++] = (byte) (v & 0x7F); + byte[] out = new byte[idx]; + System.arraycopy(tmp, 0, out, 0, idx); + return out; + } + + /** ZigZag encode a signed 32-bit integer, returning as unsigned long for varint encoding. */ + private static long zigzagEncode32(int n) { + return Integer.toUnsignedLong((n << 1) ^ (n >> 31)); + } + + /** ZigZag encode a signed 64-bit integer. */ + private static long zigzagEncode64(long n) { + return (n << 1) ^ (n >> 63); + } + + /** Encode a 32-bit value in little-endian (fixed32). */ + private static byte[] encodeFixed32(int v) { + return ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(v).array(); + } + + /** Encode a 64-bit value in little-endian (fixed64). */ + private static byte[] encodeFixed64(long v) { + return ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(v).array(); + } + + /** Encode a float in little-endian (fixed32 wire type). */ + private static byte[] encodeFloat(float f) { + return ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putFloat(f).array(); + } + + /** Encode a double in little-endian (fixed64 wire type). */ + private static byte[] encodeDouble(double d) { + return ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putDouble(d).array(); + } + + /** Create a protobuf tag (field number + wire type). 
*/ + private static byte[] tag(int fieldNumber, int wireType) { + return encodeVarint((fieldNumber << 3) | wireType); + } + + // Wire type constants + private static final int WT_VARINT = 0; + private static final int WT_64BIT = 1; + private static final int WT_LEN = 2; + private static final int WT_32BIT = 5; + + private static Byte[] box(byte[] bytes) { + if (bytes == null) return null; + Byte[] out = new Byte[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + out[i] = bytes[i]; + } + return out; + } + + private static Byte[] concat(Byte[]... parts) { + int len = 0; + for (Byte[] p : parts) if (p != null) len += p.length; + Byte[] out = new Byte[len]; + int pos = 0; + for (Byte[] p : parts) { + if (p != null) { + System.arraycopy(p, 0, out, pos, p.length); + pos += p.length; + } + } + return out; + } + + // ============================================================================ + // Basic Type Tests + // ============================================================================ + + @Test + void decodeVarintAndStringToStruct() { + // message Msg { int64 id = 1; string name = 2; } + // Row0: id=100, name="alice" + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(100)), + box(tag(2, WT_LEN)), + box(encodeVarint(5)), + box("alice".getBytes())); + + // Row1: id=200, name missing + Byte[] row1 = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(200))); + + // Row2: null input message + Byte[] row2 = null; + + try (Table input = new Table.TestBuilder().column(row0, row1, row2).build(); + ColumnVector expectedId = ColumnVector.fromBoxedLongs(100L, 200L, null); + ColumnVector expectedName = ColumnVector.fromStrings("alice", null, null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedId, expectedName); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{0, 0})) { + 
AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void decodeMoreTypes() { + // message Msg { uint32 u32 = 1; sint64 s64 = 2; fixed32 f32 = 3; bytes b = 4; } + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(4000000000L)), + box(tag(2, WT_VARINT)), + box(encodeVarint(zigzagEncode64(-1234567890123L))), + box(tag(3, WT_32BIT)), + box(encodeFixed32(12345)), + box(tag(4, WT_LEN)), + box(encodeVarint(3)), + box(new byte[]{1, 2, 3})); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row0}).build(); + ColumnVector expectedU32 = ColumnVector.fromBoxedLongs(4000000000L); + ColumnVector expectedS64 = ColumnVector.fromBoxedLongs(-1234567890123L); + ColumnVector expectedF32 = ColumnVector.fromBoxedInts(12345); + ColumnVector expectedB = ColumnVector.fromLists( + new ListType(true, new BasicType(true, DType.INT8)), + Arrays.asList((byte) 1, (byte) 2, (byte) 3)); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1, 2, 3, 4}, + new int[]{ + DType.UINT32.getTypeId().getNativeId(), + DType.INT64.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.LIST.getTypeId().getNativeId()}, + new int[]{ + Protobuf.ENC_DEFAULT, + Protobuf.ENC_ZIGZAG, + Protobuf.ENC_FIXED, + Protobuf.ENC_DEFAULT})) { + try (ColumnVector expectedU32Correct = expectedU32.castTo(DType.UINT32); + ColumnVector expectedStructCorrect = ColumnVector.makeStruct( + expectedU32Correct, expectedS64, expectedF32, expectedB)) { + AssertUtils.assertStructColumnsAreEqual(expectedStructCorrect, actualStruct); + } + } + } + + @Test + void decodeFloatDoubleAndBool() { + // message Msg { bool flag = 1; float f32 = 2; double f64 = 3; } + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), new Byte[]{(byte)0x01}, // bool=true + box(tag(2, WT_32BIT)), box(encodeFloat(3.14f)), + box(tag(3, WT_64BIT)), box(encodeDouble(2.71828))); + + Byte[] row1 = concat( + box(tag(1, WT_VARINT)), new 
Byte[]{(byte)0x00}, // bool=false + box(tag(2, WT_32BIT)), box(encodeFloat(-1.5f)), + box(tag(3, WT_64BIT)), box(encodeDouble(0.0))); + + try (Table input = new Table.TestBuilder().column(row0, row1).build(); + ColumnVector expectedBool = ColumnVector.fromBoxedBooleans(true, false); + ColumnVector expectedFloat = ColumnVector.fromBoxedFloats(3.14f, -1.5f); + ColumnVector expectedDouble = ColumnVector.fromBoxedDoubles(2.71828, 0.0); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedBool, expectedFloat, expectedDouble); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1, 2, 3}, + new int[]{ + DType.BOOL8.getTypeId().getNativeId(), + DType.FLOAT32.getTypeId().getNativeId(), + DType.FLOAT64.getTypeId().getNativeId()}, + new int[]{0, 0, 0})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Varint Boundary Tests + // ============================================================================ + + @Test + void testVarintMaxUint64() { + // Max uint64 = 0xFFFFFFFFFFFFFFFF = 18446744073709551615 + // Encoded as 10 bytes: FF FF FF FF FF FF FF FF FF 01 + Byte[] row = concat( + box(tag(1, WT_VARINT)), + new Byte[]{(byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0x01}); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.UINT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + try (ColumnVector expectedU64 = ColumnVector.fromBoxedLongs(-1L); // -1 as unsigned = max + ColumnVector expectedU64Correct = expectedU64.castTo(DType.UINT64); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedU64Correct)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); 
+ } + } + } + + @Test + void testVarintZero() { + // Zero encoded as single byte: 0x00 + Byte[] row = concat(box(tag(1, WT_VARINT)), new Byte[]{0x00}); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(0L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testVarintOverEncodedZero() { + // Zero over-encoded as 10 bytes (all continuation bits except last) + // This is valid per protobuf spec - parsers must accept non-canonical varints + Byte[] row = concat( + box(tag(1, WT_VARINT)), + new Byte[]{(byte)0x80, (byte)0x80, (byte)0x80, (byte)0x80, (byte)0x80, + (byte)0x80, (byte)0x80, (byte)0x80, (byte)0x80, (byte)0x00}); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(0L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testVarint10thByteInvalid() { + // 10th byte with more than 1 significant bit is invalid + // (uint64 can only hold 64 bits: 9*7=63 bits + 1 bit from 10th byte) + Byte[] row = concat( + box(tag(1, WT_VARINT)), + new Byte[]{(byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0x02}); // 0x02 has 2nd bit set + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = 
Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + false)) { + try (ColumnVector expected = ColumnVector.fromBoxedLongs((Long)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); + } + } + } + + // ============================================================================ + // ZigZag Boundary Tests + // ============================================================================ + + @Test + void testZigzagInt32Min() { + // int32 min = -2147483648 + // zigzag encoded = 4294967295 = 0xFFFFFFFF + int minInt32 = Integer.MIN_VALUE; + Byte[] row = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(zigzagEncode32(minInt32)))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(minInt32); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ZIGZAG})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testZigzagInt32Max() { + // int32 max = 2147483647 + // zigzag encoded = 4294967294 = 0xFFFFFFFE + int maxInt32 = Integer.MAX_VALUE; + Byte[] row = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(zigzagEncode32(maxInt32)))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(maxInt32); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ZIGZAG})) { + 
AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testZigzagInt64Min() { + // int64 min = -9223372036854775808 + long minInt64 = Long.MIN_VALUE; + Byte[] row = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(zigzagEncode64(minInt64)))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedLong = ColumnVector.fromBoxedLongs(minInt64); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedLong); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ZIGZAG})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testZigzagInt64Max() { + long maxInt64 = Long.MAX_VALUE; + Byte[] row = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(zigzagEncode64(maxInt64)))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedLong = ColumnVector.fromBoxedLongs(maxInt64); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedLong); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ZIGZAG})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testZigzagNegativeOne() { + // -1 zigzag encoded = 1 + Byte[] row = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(zigzagEncode64(-1L)))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedLong = ColumnVector.fromBoxedLongs(-1L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedLong); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ZIGZAG})) { + 
AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Truncated/Malformed Data Tests + // ============================================================================ + + @Test + void testMalformedVarint() { + // Varint that never terminates (all continuation bits set, 11 bytes) + Byte[] malformed = new Byte[]{(byte)0x08, (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{malformed}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + false)) { + try (ColumnVector expected = ColumnVector.fromBoxedLongs((Long)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); + } + } + } + + @Test + void testTruncatedVarint() { + // Single byte with continuation bit set but no following byte + Byte[] truncated = concat(box(tag(1, WT_VARINT)), new Byte[]{(byte)0x80}); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + false)) { + try (ColumnVector expected = ColumnVector.fromBoxedLongs((Long)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); + } + } + } + + @Test + void testTruncatedLengthDelimited() { + // String field with length=5 but no actual data + Byte[] truncated = concat(box(tag(2, WT_LEN)), box(encodeVarint(5))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build(); + ColumnVector result = Protobuf.decodeToStruct( 
+ input.getColumn(0), + new int[]{2}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{0}, + false)) { + try (ColumnVector expected = ColumnVector.fromStrings((String)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); + } + } + } + + @Test + void testTruncatedFixed32() { + // Fixed32 needs 4 bytes but only 3 provided + Byte[] truncated = concat(box(tag(1, WT_32BIT)), new Byte[]{0x01, 0x02, 0x03}); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_FIXED}, + false)) { + try (ColumnVector expected = ColumnVector.fromBoxedInts((Integer)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); + } + } + } + + @Test + void testTruncatedFixed64() { + // Fixed64 needs 8 bytes but only 7 provided + Byte[] truncated = concat(box(tag(1, WT_64BIT)), + new Byte[]{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_FIXED}, + false)) { + try (ColumnVector expected = ColumnVector.fromBoxedLongs((Long)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); + } + } + } + + @Test + void testPartialLengthDelimitedData() { + // Length says 10 bytes but only 5 provided + Byte[] partial = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(10)), + box("hello".getBytes())); // only 5 bytes + try (Table input = new Table.TestBuilder().column(new Byte[][]{partial}).build(); + 
ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{0}, + false)) { + try (ColumnVector expected = ColumnVector.fromStrings((String)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); + } + } + } + + // ============================================================================ + // Wrong Wire Type Tests + // ============================================================================ + + @Test + void testWrongWireType() { + // Expect varint (wire type 0) but provide fixed32 (wire type 5) + Byte[] wrongType = concat( + box(tag(1, WT_32BIT)), // wire type 5 instead of 0 + box(encodeFixed32(100))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{wrongType}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, // expects varint + new int[]{Protobuf.ENC_DEFAULT}, + false)) { + try (ColumnVector expected = ColumnVector.fromBoxedLongs((Long)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); + } + } + } + + @Test + void testWrongWireTypeForString() { + // Expect length-delimited (wire type 2) but provide varint (wire type 0) + Byte[] wrongType = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(12345))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{wrongType}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, // expects LEN + new int[]{Protobuf.ENC_DEFAULT}, + false)) { + try (ColumnVector expected = ColumnVector.fromStrings((String)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, 
result); + } + } + } + + // ============================================================================ + // Unknown Field Skip Tests + // ============================================================================ + + @Test + void testSkipUnknownVarintField() { + // Unknown field 99 with varint, followed by known field 1 + Byte[] row = concat( + box(tag(99, WT_VARINT)), + box(encodeVarint(12345)), // unknown field to skip + box(tag(1, WT_VARINT)), + box(encodeVarint(42))); // known field + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(42L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testSkipUnknownFixed64Field() { + // Unknown field 99 with fixed64, followed by known field 1 + Byte[] row = concat( + box(tag(99, WT_64BIT)), + box(encodeFixed64(0x123456789ABCDEF0L)), // unknown field to skip + box(tag(1, WT_VARINT)), + box(encodeVarint(42))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(42L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testSkipUnknownLengthDelimitedField() { + // Unknown field 99 with length-delimited data, followed by known field 1 + Byte[] row = concat( + box(tag(99, WT_LEN)), + box(encodeVarint(5)), + box("hello".getBytes()), // unknown field to skip + 
box(tag(1, WT_VARINT)), + box(encodeVarint(42))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(42L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testSkipUnknownFixed32Field() { + // Unknown field 99 with fixed32, followed by known field 1 + Byte[] row = concat( + box(tag(99, WT_32BIT)), + box(encodeFixed32(12345)), // unknown field to skip + box(tag(1, WT_VARINT)), + box(encodeVarint(42))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(42L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Last One Wins (Repeated Scalar Field) Tests + // ============================================================================ + + @Test + void testLastOneWins() { + // Same field appears multiple times - last value should win + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(100)), + box(tag(1, WT_VARINT)), box(encodeVarint(200)), + box(tag(1, WT_VARINT)), box(encodeVarint(300))); // this should win + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(300L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + 
ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testLastOneWinsForString() { + // Same string field appears multiple times + Byte[] row = concat( + box(tag(1, WT_LEN)), box(encodeVarint(5)), box("first".getBytes()), + box(tag(1, WT_LEN)), box(encodeVarint(6)), box("second".getBytes()), + box(tag(1, WT_LEN)), box(encodeVarint(4)), box("last".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedStr = ColumnVector.fromStrings("last"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedStr); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Error Handling Tests + // ============================================================================ + + @Test + void testFailOnErrorsTrue() { + Byte[] malformed = new Byte[]{(byte)0x08, (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{malformed}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + true)) { + } + }); + } + } + + @Test + void testFieldNumberZeroInvalid() { + // Field number 0 is reserved and invalid + Byte[] invalid = concat(box(tag(0, WT_VARINT)), box(encodeVarint(123))); + try (Table input = new 
Table.TestBuilder().column(new Byte[][]{invalid}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + false)) { + try (ColumnVector expected = ColumnVector.fromBoxedLongs((Long)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); + } + } + } + + @Test + void testEmptyMessage() { + // Empty message should result in null/default values for all fields + Byte[] empty = new Byte[0]; + try (Table input = new Table.TestBuilder().column(new Byte[][]{empty}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs((Long)null); + ColumnVector expectedStr = ColumnVector.fromStrings((String)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt, expectedStr); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{0, 0})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Float/Double Special Values Tests + // ============================================================================ + + @Test + void testFloatSpecialValues() { + Byte[] rowInf = concat(box(tag(1, WT_32BIT)), box(encodeFloat(Float.POSITIVE_INFINITY))); + Byte[] rowNegInf = concat(box(tag(1, WT_32BIT)), box(encodeFloat(Float.NEGATIVE_INFINITY))); + Byte[] rowNaN = concat(box(tag(1, WT_32BIT)), box(encodeFloat(Float.NaN))); + Byte[] rowMin = concat(box(tag(1, WT_32BIT)), box(encodeFloat(Float.MIN_VALUE))); + Byte[] rowMax = concat(box(tag(1, WT_32BIT)), box(encodeFloat(Float.MAX_VALUE))); + + try (Table input = new Table.TestBuilder().column(rowInf, rowNegInf, rowNaN, rowMin, rowMax).build(); + ColumnVector expectedFloat = 
ColumnVector.fromBoxedFloats( + Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY, Float.NaN, + Float.MIN_VALUE, Float.MAX_VALUE); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedFloat); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.FLOAT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDoubleSpecialValues() { + Byte[] rowInf = concat(box(tag(1, WT_64BIT)), box(encodeDouble(Double.POSITIVE_INFINITY))); + Byte[] rowNegInf = concat(box(tag(1, WT_64BIT)), box(encodeDouble(Double.NEGATIVE_INFINITY))); + Byte[] rowNaN = concat(box(tag(1, WT_64BIT)), box(encodeDouble(Double.NaN))); + Byte[] rowMin = concat(box(tag(1, WT_64BIT)), box(encodeDouble(Double.MIN_VALUE))); + Byte[] rowMax = concat(box(tag(1, WT_64BIT)), box(encodeDouble(Double.MAX_VALUE))); + + try (Table input = new Table.TestBuilder().column(rowInf, rowNegInf, rowNaN, rowMin, rowMax).build(); + ColumnVector expectedDouble = ColumnVector.fromBoxedDoubles( + Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, Double.NaN, + Double.MIN_VALUE, Double.MAX_VALUE); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedDouble); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.FLOAT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Tests for Features Not Yet Implemented (Disabled) + // ============================================================================ + + @Disabled("Unpacked repeated fields not yet implemented") + @Test + void testUnpackedRepeatedInt32() { + // Unpacked repeated: same field number appears multiple times + Byte[] row = concat( + box(tag(1, 
WT_VARINT)), box(encodeVarint(1)), + box(tag(1, WT_VARINT)), box(encodeVarint(2)), + box(tag(1, WT_VARINT)), box(encodeVarint(3))); + + // Expected: ARRAY with values [1, 2, 3] + // (Currently we implement "last one wins" semantics for scalars) + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + // TODO: implement unpacked repeated field decoding + } + } + + @Disabled("Nested messages not yet implemented") + @Test + void testNestedMessage() { + // message Inner { int32 x = 1; } + // message Outer { Inner inner = 1; } + // Outer with inner.x = 42 + Byte[] innerMessage = concat(box(tag(1, WT_VARINT)), box(encodeVarint(42))); + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(innerMessage.length)), + innerMessage); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + // TODO: implement nested message decoding + // Expected: STRUCT> + } + } + + @Disabled("Large field numbers not tested with current API") + @Test + void testLargeFieldNumber() { + // Field numbers can be up to 2^29 - 1 = 536870911 + int largeFieldNum = 536870911; + Byte[] row = concat( + box(tag(largeFieldNum, WT_VARINT)), + box(encodeVarint(42))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + // Current API uses int[] for field numbers, should work + // But need to verify kernel handles large field numbers correctly + } + } + + // ============================================================================ + // FAILFAST Mode Tests (failOnErrors = true) + // ============================================================================ + + @Test + void testFailfastMalformedVarint() { + // Varint that never terminates (all continuation bits set) + Byte[] malformed = new Byte[]{(byte)0x08, (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF}; + try (Table input = new Table.TestBuilder().column(new 
Byte[][]{malformed}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + true)) { // failOnErrors = true + } + }); + } + } + + @Test + void testFailfastTruncatedVarint() { + // Single byte with continuation bit set but no following byte + Byte[] truncated = concat(box(tag(1, WT_VARINT)), new Byte[]{(byte)0x80}); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + true)) { + } + }); + } + } + + @Test + void testFailfastTruncatedString() { + // String field with length=5 but no actual data + Byte[] truncated = concat(box(tag(2, WT_LEN)), box(encodeVarint(5))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{2}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{0}, + true)) { + } + }); + } + } + + @Test + void testFailfastTruncatedFixed32() { + // Fixed32 needs 4 bytes but only 3 provided + Byte[] truncated = concat(box(tag(1, WT_32BIT)), new Byte[]{0x01, 0x02, 0x03}); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_FIXED}, + true)) { + } + }); + } + } + + @Test + void testFailfastTruncatedFixed64() { + // Fixed64 needs 8 bytes but only 5 provided + Byte[] truncated = 
concat(box(tag(1, WT_64BIT)), new Byte[]{0x01, 0x02, 0x03, 0x04, 0x05}); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_FIXED}, + true)) { + } + }); + } + } + + @Test + void testFailfastWrongWireType() { + // Field 1 with wire type 2 (length-delimited), but we request varint + Byte[] row = concat(box(tag(1, WT_LEN)), box(encodeVarint(3)), box("abc".getBytes())); + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + true)) { + } + }); + } + } + + @Test + void testFailfastFieldNumberZero() { + // Field number 0 is invalid in protobuf + Byte[] row = concat(box(tag(0, WT_VARINT)), box(encodeVarint(42))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + true)) { + } + }); + } + } + + @Test + void testFailfastValidDataDoesNotThrow() { + // Valid protobuf should not throw even with failOnErrors = true + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(42))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + true)) { + try (ColumnVector expected = ColumnVector.fromBoxedLongs(42L); + ColumnVector 
expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); + } + } + } +} From 3701cdf14aff8649bf502f8dcc15563c40068630 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 9 Jan 2026 16:15:27 +0800 Subject: [PATCH 009/107] style Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 12 ++++++------ src/main/cpp/src/protobuf.cu | 19 ++++++++----------- 2 files changed, 14 insertions(+), 17 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index d36e791b51..17cc602d7c 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -25,12 +25,12 @@ extern "C" { JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, - jclass, - jlong binary_input_view, - jintArray field_numbers, - jintArray type_ids, - jintArray type_scales, - jboolean fail_on_errors) + jclass, + jlong binary_input_view, + jintArray field_numbers, + jintArray type_ids, + jintArray type_scales, + jboolean fail_on_errors) { JNI_NULL_CHECK(env, binary_input_view, "binary_input_view is null", 0); JNI_NULL_CHECK(env, field_numbers, "field_numbers is null", 0); diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 7bac1ec74f..8c7394c5b2 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -579,19 +579,16 @@ std::unique_ptr decode_protobuf_to_struct( auto chars_size = static_cast(scv.chars_size(stream)); rmm::device_uvector chars_vec(chars_size, stream, mr); if (chars_size > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(chars_vec.data(), - chars_data, - chars_size, - cudaMemcpyDeviceToDevice, - stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync( + chars_vec.data(), chars_data, chars_size, cudaMemcpyDeviceToDevice, stream.value())); } // Create INT8 column from chars data - auto child_col = std::make_unique( - cudf::data_type{cudf::type_id::INT8}, - chars_size, - chars_vec.release(), - 
rmm::device_buffer{}, // no null mask for chars - 0); // no nulls + auto child_col = + std::make_unique(cudf::data_type{cudf::type_id::INT8}, + chars_size, + chars_vec.release(), + rmm::device_buffer{}, // no null mask for chars + 0); // no nulls // Get null mask auto null_mask = cudf::copy_bitmask(*strings, stream, mr); From 1bf76602a3e8bd8775dac938264af3ddbe36ebba Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 13 Jan 2026 11:23:50 +0800 Subject: [PATCH 010/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 17 ++++++++++++++++- src/main/cpp/src/protobuf.hpp | 5 ++++- .../com/nvidia/spark/rapids/jni/Protobuf.java | 5 +++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 8c7394c5b2..f8c581757b 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -407,9 +407,21 @@ std::unique_ptr decode_protobuf_to_struct( auto const stream = cudf::get_default_stream(); auto mr = cudf::get_current_device_resource_ref(); - auto d_in = cudf::column_device_view::create(binary_input, stream); auto rows = binary_input.size(); + // Handle zero-row case explicitly - return empty STRUCT with properly typed children + if (rows == 0) { + std::vector> empty_children; + empty_children.reserve(out_types.size()); + for (auto const& dt : out_types) { + empty_children.push_back(cudf::make_empty_column(dt)); + } + return cudf::make_structs_column( + 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); + } + + auto d_in = cudf::column_device_view::create(binary_input, stream); + // Track parse errors across kernels. rmm::device_uvector d_error(1, stream, mr); CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); @@ -610,6 +622,9 @@ std::unique_ptr decode_protobuf_to_struct( CUDF_CUDA_TRY(cudaPeekAtLastError()); // Check for any parse errors. 
+ // Note: We check errors after all kernels complete rather than between kernel launches + // to avoid expensive synchronization overhead. If fail_on_errors is true and an error + // occurred, all kernels will have executed but we throw an exception here. int h_error = 0; CUDF_CUDA_TRY( cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index d77a13c653..15c27e7f0a 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -53,7 +53,10 @@ namespace spark_rapids_jni { * @param field_numbers protobuf field numbers (one per output child) * @param out_types output cudf data types (one per output child) * @param encodings encoding type for each field (0=default, 1=fixed, 2=zigzag) - * @param fail_on_errors whether to throw on malformed messages + * @param fail_on_errors whether to throw on malformed messages. Note: error checking is performed + * after all kernels complete (not between kernel launches) to avoid synchronization + * overhead. If an error is detected, all kernels will have executed but an exception will be + * thrown. * @return STRUCT column with the given children types; the STRUCT itself is always non-null, * and individual child fields may be null when input message is null or field is missing */ diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index 8b9435a31b..a091bdbcbd 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -76,6 +76,8 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, * - For decimal types, this is the scale (currently unsupported). * @param failOnErrors if true, throw an exception on malformed protobuf messages. * If false, return nulls for fields that cannot be parsed. 
+ * Note: error checking is performed after all fields are processed, + * not between fields, to avoid synchronization overhead. * @return a cudf STRUCT column where children correspond 1:1 with {@code fieldNumbers}/{@code typeIds}. */ public static ColumnVector decodeToStruct(ColumnView binaryInput, @@ -86,6 +88,9 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, if (fieldNumbers == null || typeIds == null || typeScales == null) { throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be non-null"); } + if (fieldNumbers.length == 0) { + throw new IllegalArgumentException("fieldNumbers must not be empty"); + } if (fieldNumbers.length != typeIds.length || fieldNumbers.length != typeScales.length) { throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be the same length"); } From b01242117d7d1dbc0b05e0e02a7637d708d13f71 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 20 Jan 2026 19:34:57 +0800 Subject: [PATCH 011/107] multi column processing Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 57 +- src/main/cpp/src/protobuf.cu | 1124 +++++++++++------ src/main/cpp/src/protobuf.hpp | 36 +- .../com/nvidia/spark/rapids/jni/Protobuf.java | 141 ++- .../nvidia/spark/rapids/jni/ProtobufTest.java | 214 +++- 5 files changed, 1054 insertions(+), 518 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index 17cc602d7c..d2b20923ab 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -27,51 +27,58 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, jclass, jlong binary_input_view, + jint total_num_fields, + jintArray decoded_field_indices, jintArray field_numbers, - jintArray type_ids, - jintArray type_scales, + jintArray all_type_ids, + jintArray encodings, jboolean fail_on_errors) { JNI_NULL_CHECK(env, binary_input_view, "binary_input_view is null", 0); + JNI_NULL_CHECK(env, 
decoded_field_indices, "decoded_field_indices is null", 0); JNI_NULL_CHECK(env, field_numbers, "field_numbers is null", 0); - JNI_NULL_CHECK(env, type_ids, "type_ids is null", 0); - JNI_NULL_CHECK(env, type_scales, "type_scales is null", 0); + JNI_NULL_CHECK(env, all_type_ids, "all_type_ids is null", 0); + JNI_NULL_CHECK(env, encodings, "encodings is null", 0); JNI_TRY { cudf::jni::auto_set_device(env); auto const* input = reinterpret_cast(binary_input_view); + + cudf::jni::native_jintArray n_decoded_indices(env, decoded_field_indices); cudf::jni::native_jintArray n_field_numbers(env, field_numbers); - cudf::jni::native_jintArray n_type_ids(env, type_ids); - cudf::jni::native_jintArray n_type_scales(env, type_scales); - if (n_field_numbers.size() != n_type_ids.size() || - n_field_numbers.size() != n_type_scales.size()) { + cudf::jni::native_jintArray n_all_type_ids(env, all_type_ids); + cudf::jni::native_jintArray n_encodings(env, encodings); + + // Validate array sizes + if (n_decoded_indices.size() != n_field_numbers.size() || + n_decoded_indices.size() != n_encodings.size()) { JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, - "fieldNumbers/typeIds/typeScales must be the same length", + "decoded_field_indices/field_numbers/encodings must be the same length", + 0); + } + if (n_all_type_ids.size() != total_num_fields) { + JNI_THROW_NEW(env, + cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, + "all_type_ids size must equal total_num_fields", 0); } + std::vector decoded_indices(n_decoded_indices.begin(), n_decoded_indices.end()); std::vector field_nums(n_field_numbers.begin(), n_field_numbers.end()); - std::vector encodings(n_type_scales.begin(), n_type_scales.end()); - std::vector out_types; - out_types.reserve(n_type_ids.size()); - for (int i = 0; i < n_type_ids.size(); ++i) { - // For protobuf decoding, typeScales contains encoding info (0=default, 1=fixed, - // 2=zigzag) not decimal scales. For non-decimal types, scale should be 0. 
Decimal types are - // not currently supported in protobuf decoder. - auto type_id = static_cast(n_type_ids[i]); - if (cudf::is_fixed_point(cudf::data_type{type_id})) { - // For decimal types, use the scale from typeScales (though currently unsupported) - out_types.emplace_back(cudf::jni::make_data_type(n_type_ids[i], n_type_scales[i])); - } else { - // For non-decimal types, scale is always 0; typeScales contains encoding info - out_types.emplace_back(cudf::jni::make_data_type(n_type_ids[i], 0)); - } + std::vector encs(n_encodings.begin(), n_encodings.end()); + + // Build all_types vector - types for ALL fields in the output struct + std::vector all_types; + all_types.reserve(total_num_fields); + for (int i = 0; i < total_num_fields; ++i) { + // For non-decimal types, scale is always 0 + all_types.emplace_back(cudf::jni::make_data_type(n_all_type_ids[i], 0)); } auto result = spark_rapids_jni::decode_protobuf_to_struct( - *input, field_nums, out_types, encodings, fail_on_errors); + *input, total_num_fields, decoded_indices, field_nums, all_types, encs, fail_on_errors); return cudf::jni::release_as_jlong(result); } JNI_CATCH(env, 0); diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index f8c581757b..e0d2a087b7 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -33,9 +33,11 @@ #include #include #include +#include namespace { +// Wire type constants constexpr int WT_VARINT = 0; constexpr int WT_64BIT = 1; constexpr int WT_LEN = 2; @@ -43,28 +45,40 @@ constexpr int WT_32BIT = 5; } // namespace -namespace spark_rapids_jni { +namespace { -constexpr int ENC_DEFAULT = 0; -constexpr int ENC_FIXED = 1; -constexpr int ENC_ZIGZAG = 2; +/** + * Structure to record field location within a message. + * offset < 0 means field was not found. 
+ */ +struct field_location { + int32_t offset; // Offset of field data within the message (-1 if not found) + int32_t length; // Length of field data in bytes +}; -} // namespace spark_rapids_jni +/** + * Field descriptor passed to the scanning kernel. + */ +struct field_descriptor { + int field_number; // Protobuf field number + int expected_wire_type; // Expected wire type for this field +}; -namespace { +// ============================================================================ +// Device helper functions +// ============================================================================ __device__ inline bool read_varint(uint8_t const* cur, uint8_t const* end, uint64_t& out, int& bytes) { - out = 0; - bytes = 0; + out = 0; + bytes = 0; int shift = 0; while (cur < end && bytes < 10) { uint8_t b = *cur++; // For the 10th byte (bytes == 9, shift == 63), only the lowest bit is valid - // since we can only fit 1 more bit into uint64_t if (bytes == 9 && (b & 0xFE) != 0) { return false; // Invalid: 10th byte has more than 1 significant bit } @@ -76,43 +90,51 @@ __device__ inline bool read_varint(uint8_t const* cur, return false; } -__device__ inline bool skip_field(uint8_t const* cur, - uint8_t const* end, - int wt, - uint8_t const*& out_cur) +__device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) { - out_cur = cur; switch (wt) { case WT_VARINT: { - uint64_t tmp; - int n; - if (!read_varint(out_cur, end, tmp, n)) return false; - out_cur += n; - return true; + // Need to scan to find the end of varint + int count = 0; + while (cur < end && count < 10) { + if ((*cur++ & 0x80u) == 0) { return count + 1; } + count++; + } + return -1; // Invalid varint } case WT_64BIT: - if (end - out_cur < 8) return false; - out_cur += 8; - return true; + // Check if there's enough data for 8 bytes + if (end - cur < 8) return -1; + return 8; case WT_32BIT: - if (end - out_cur < 4) return false; - out_cur += 4; - return true; + // Check if there's 
enough data for 4 bytes + if (end - cur < 4) return -1; + return 4; case WT_LEN: { - uint64_t len64; + uint64_t len; int n; - if (!read_varint(out_cur, end, len64, n)) return false; - out_cur += n; - // Check for both buffer overflow and int overflow - if (len64 > static_cast(end - out_cur) || len64 > static_cast(INT_MAX)) - return false; - out_cur += static_cast(len64); - return true; + if (!read_varint(cur, end, len, n)) return -1; + if (len > static_cast(end - cur - n) || len > static_cast(INT_MAX)) + return -1; + return n + static_cast(len); } - default: return false; + default: return -1; } } +__device__ inline bool skip_field(uint8_t const* cur, + uint8_t const* end, + int wt, + uint8_t const*& out_cur) +{ + int size = get_wire_type_size(wt, cur, end); + if (size < 0) return false; + // Ensure we don't skip past the end of the buffer + if (cur + size > end) return false; + out_cur = cur + size; + return true; +} + template __device__ inline T load_le(uint8_t const* p); @@ -134,243 +156,271 @@ __device__ inline uint64_t load_le(uint8_t const* p) return v; } -template -__global__ void extract_varint_kernel( - cudf::column_device_view const d_in, int field_number, OutT* out, bool* valid, int* error_flag) +// ============================================================================ +// Pass 1: Scan all fields kernel - records (offset, length) for each field +// ============================================================================ + +/** + * Fused scanning kernel: scans each message once and records the location + * of all requested fields. + * + * For "last one wins" semantics (protobuf standard for repeated scalars), + * we continue scanning even after finding a field. 
+ */ +__global__ void scan_all_fields_kernel( + cudf::column_device_view const d_in, + field_descriptor const* field_descs, // [num_fields] + int num_fields, + field_location* locations, // [num_rows * num_fields] row-major + int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; if (row >= in.size()) return; + + // Initialize all field locations to "not found" + for (int f = 0; f < num_fields; f++) { + locations[row * num_fields + f] = {-1, 0}; + } + if (in.nullable() && in.is_null(row)) { - valid[row] = false; - return; + return; // Null input row - all fields remain "not found" } - // Use sliced child + offsets normalized to the slice start to correctly handle - // list columns with non-zero row offsets (and any child offsets). auto const base = in.offset_at(0); auto const child = in.get_sliced_child(); auto const* bytes = reinterpret_cast(child.data()); auto start = in.offset_at(row) - base; auto end = in.offset_at(row + 1) - base; - // Defensive bounds checks: if offsets are inconsistent, avoid illegal memory access. 
+ + // Bounds check if (start < 0 || end < start || end > child.size()) { atomicExch(error_flag, 1); - valid[row] = false; return; } + uint8_t const* cur = bytes + start; uint8_t const* stop = bytes + end; - bool found = false; - OutT value{}; + // Scan the message once, recording locations of all target fields while (cur < stop) { uint64_t key; int key_bytes; if (!read_varint(cur, stop, key, key_bytes)) { atomicExch(error_flag, 1); - break; + return; } cur += key_bytes; + int fn = static_cast(key >> 3); int wt = static_cast(key & 0x7); + if (fn == 0) { atomicExch(error_flag, 1); - break; + return; } - if (fn == field_number) { - if (wt != WT_VARINT) { - atomicExch(error_flag, 1); - break; - } - uint64_t v; - int n; - if (!read_varint(cur, stop, v, n)) { - atomicExch(error_flag, 1); - break; - } - cur += n; - if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } - value = static_cast(v); - found = true; - // Continue scanning to allow "last one wins" semantics. - } else { - uint8_t const* next; - if (!skip_field(cur, stop, wt, next)) { - atomicExch(error_flag, 1); - break; + + // Check if this field is one we're looking for + for (int f = 0; f < num_fields; f++) { + if (field_descs[f].field_number == fn) { + // Check wire type matches + if (wt != field_descs[f].expected_wire_type) { + atomicExch(error_flag, 1); + return; + } + + // Record the location (relative to message start) + int data_offset = static_cast(cur - bytes - start); + + if (wt == WT_LEN) { + // For length-delimited, record offset after length prefix and the data length + uint64_t len; + int len_bytes; + if (!read_varint(cur, stop, len, len_bytes)) { + atomicExch(error_flag, 1); + return; + } + if (len > static_cast(stop - cur - len_bytes) || + len > static_cast(INT_MAX)) { + atomicExch(error_flag, 1); + return; + } + // Record offset pointing to the actual data (after length prefix) + locations[row * num_fields + f] = {data_offset + len_bytes, + static_cast(len)}; + } else { + // For fixed-size and 
varint fields, record offset and compute length + int field_size = get_wire_type_size(wt, cur, stop); + if (field_size < 0) { + atomicExch(error_flag, 1); + return; + } + locations[row * num_fields + f] = {data_offset, field_size}; + } + // Don't break - continue to support "last one wins" semantics } - cur = next; } - } - if (found) { - out[row] = value; - valid[row] = true; - } else { - valid[row] = false; + // Skip to next field + uint8_t const* next; + if (!skip_field(cur, stop, wt, next)) { + atomicExch(error_flag, 1); + return; + } + cur = next; } } -template -__global__ void extract_fixed_kernel( - cudf::column_device_view const d_in, int field_number, OutT* out, bool* valid, int* error_flag) +// ============================================================================ +// Pass 2: Extract data kernels +// ============================================================================ + +/** + * Extract varint field data using pre-recorded locations. + */ +template +__global__ void extract_varint_from_locations_kernel( + uint8_t const* message_data, + cudf::size_type const* offsets, // List offsets for each row + cudf::size_type base_offset, + field_location const* locations, // [num_rows * num_fields] + int field_idx, + int num_fields, + OutT* out, + bool* valid, + int num_rows, + int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - cudf::detail::lists_column_device_view in{d_in}; - if (row >= in.size()) return; - if (in.nullable() && in.is_null(row)) { + if (row >= num_rows) return; + + auto loc = locations[row * num_fields + field_idx]; + if (loc.offset < 0) { valid[row] = false; return; } - auto const base = in.offset_at(0); - auto const child = in.get_sliced_child(); - auto const* bytes = reinterpret_cast(child.data()); - auto start = in.offset_at(row) - base; - auto end = in.offset_at(row + 1) - base; - if (start < 0 || end < start || end > child.size()) { + // Calculate absolute offset in the message data + auto row_start 
= offsets[row] - base_offset; + uint8_t const* cur = message_data + row_start + loc.offset; + uint8_t const* cur_end = cur + loc.length; + + uint64_t v; + int n; + if (!read_varint(cur, cur_end, v, n)) { atomicExch(error_flag, 1); valid[row] = false; return; } - uint8_t const* cur = bytes + start; - uint8_t const* stop = bytes + end; - - bool found = false; - OutT value{}; - while (cur < stop) { - uint64_t key; - int key_bytes; - if (!read_varint(cur, stop, key, key_bytes)) { - atomicExch(error_flag, 1); - break; - } - cur += key_bytes; - int fn = static_cast(key >> 3); - int wt = static_cast(key & 0x7); - if (fn == 0) { - atomicExch(error_flag, 1); - break; - } - if (fn == field_number) { - if (wt != WT) { - atomicExch(error_flag, 1); - break; - } - if constexpr (WT == WT_32BIT) { - if (stop - cur < 4) { - atomicExch(error_flag, 1); - break; - } - uint32_t raw = load_le(cur); - cur += 4; - // Use memcpy to avoid undefined behavior from type punning - memcpy(&value, &raw, sizeof(value)); - } else { - if (stop - cur < 8) { - atomicExch(error_flag, 1); - break; - } - uint64_t raw = load_le(cur); - cur += 8; - // Use memcpy to avoid undefined behavior from type punning - memcpy(&value, &raw, sizeof(value)); - } - found = true; - } else { - uint8_t const* next; - if (!skip_field(cur, stop, wt, next)) { - atomicExch(error_flag, 1); - break; - } - cur = next; - } - } - if (found) { - out[row] = value; - valid[row] = true; - } else { - valid[row] = false; - } + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } + out[row] = static_cast(v); + valid[row] = true; } -__global__ void extract_string_kernel(cudf::column_device_view const d_in, - int field_number, - cudf::strings::detail::string_index_pair* out_pairs, - int* error_flag) +/** + * Extract fixed-size field data (fixed32, fixed64, float, double). 
+ */ +template +__global__ void extract_fixed_from_locations_kernel( + uint8_t const* message_data, + cudf::size_type const* offsets, + cudf::size_type base_offset, + field_location const* locations, + int field_idx, + int num_fields, + OutT* out, + bool* valid, + int num_rows, + int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - cudf::detail::lists_column_device_view in{d_in}; - if (row >= in.size()) return; - if (in.nullable() && in.is_null(row)) { - out_pairs[row] = cudf::strings::detail::string_index_pair{nullptr, 0}; - return; - } + if (row >= num_rows) return; - auto const base = in.offset_at(0); - auto const child = in.get_sliced_child(); - auto const* bytes = reinterpret_cast(child.data()); - auto start = in.offset_at(row) - base; - auto end = in.offset_at(row + 1) - base; - if (start < 0 || end < start || end > child.size()) { - atomicExch(error_flag, 1); - out_pairs[row] = cudf::strings::detail::string_index_pair{nullptr, 0}; + auto loc = locations[row * num_fields + field_idx]; + if (loc.offset < 0) { + valid[row] = false; return; } - uint8_t const* cur = bytes + start; - uint8_t const* stop = bytes + end; - cudf::strings::detail::string_index_pair pair{nullptr, 0}; - while (cur < stop) { - uint64_t key; - int key_bytes; - if (!read_varint(cur, stop, key, key_bytes)) { + auto row_start = offsets[row] - base_offset; + uint8_t const* cur = message_data + row_start + loc.offset; + + OutT value; + if constexpr (WT == WT_32BIT) { + if (loc.length < 4) { atomicExch(error_flag, 1); - break; + valid[row] = false; + return; } - cur += key_bytes; - int fn = static_cast(key >> 3); - int wt = static_cast(key & 0x7); - if (fn == 0) { + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { + if (loc.length < 8) { atomicExch(error_flag, 1); - break; - } - if (fn == field_number) { - if (wt != WT_LEN) { - atomicExch(error_flag, 1); - break; - } - uint64_t len64; - int n; - if (!read_varint(cur, stop, len64, n)) { 
- atomicExch(error_flag, 1); - break; - } - cur += n; - // Check for both buffer overflow and int overflow - if (len64 > static_cast(stop - cur) || len64 > static_cast(INT_MAX)) { - atomicExch(error_flag, 1); - break; - } - pair.first = reinterpret_cast(cur); - pair.second = static_cast(len64); - cur += static_cast(len64); - // Continue scanning to allow "last one wins". - } else { - uint8_t const* next; - if (!skip_field(cur, stop, wt, next)) { - atomicExch(error_flag, 1); - break; - } - cur = next; + valid[row] = false; + return; } + uint64_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } + + out[row] = value; + valid[row] = true; +} + +/** + * Kernel to copy variable-length data (string/bytes) to output buffer. + * Uses pre-computed output offsets from prefix sum. + */ +__global__ void copy_varlen_data_kernel( + uint8_t const* message_data, + cudf::size_type const* input_offsets, // List offsets for input rows + cudf::size_type base_offset, + field_location const* locations, + int field_idx, + int num_fields, + int32_t const* output_offsets, // Pre-computed output offsets (prefix sum) + char* output_data, + int num_rows) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + auto loc = locations[row * num_fields + field_idx]; + if (loc.offset < 0 || loc.length == 0) return; + + auto row_start = input_offsets[row] - base_offset; + uint8_t const* src = message_data + row_start + loc.offset; + char* dst = output_data + output_offsets[row]; + + // Copy data + for (int i = 0; i < loc.length; i++) { + dst[i] = static_cast(src[i]); } +} - out_pairs[row] = pair; +/** + * Kernel to extract lengths from locations for prefix sum. 
+ */ +__global__ void extract_lengths_kernel( + field_location const* locations, + int field_idx, + int num_fields, + int32_t* lengths, + int num_rows) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + auto loc = locations[row * num_fields + field_idx]; + lengths[row] = (loc.offset >= 0) ? loc.length : 0; } +// ============================================================================ +// Utility functions +// ============================================================================ + inline std::pair make_null_mask_from_valid( rmm::device_uvector const& valid, rmm::cuda_stream_view stream, @@ -382,14 +432,112 @@ inline std::pair make_null_mask_from_valid( return cudf::detail::valid_if(begin, end, pred, stream, mr); } +/** + * Get the expected wire type for a given cudf type and encoding. + */ +int get_expected_wire_type(cudf::type_id type_id, int encoding) +{ + switch (type_id) { + case cudf::type_id::BOOL8: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + if (encoding == spark_rapids_jni::ENC_FIXED) { + return (type_id == cudf::type_id::INT32 || type_id == cudf::type_id::UINT32) ? WT_32BIT + : WT_64BIT; + } + return WT_VARINT; + case cudf::type_id::FLOAT32: return WT_32BIT; + case cudf::type_id::FLOAT64: return WT_64BIT; + case cudf::type_id::STRING: + case cudf::type_id::LIST: return WT_LEN; + default: CUDF_FAIL("Unsupported type for protobuf decoding"); + } +} + +/** + * Create an all-null column of the specified type. 
+ */ +std::unique_ptr make_null_column( + cudf::data_type dtype, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (num_rows == 0) { return cudf::make_empty_column(dtype); } + + switch (dtype.id()) { + case cudf::type_id::BOOL8: + case cudf::type_id::INT8: + case cudf::type_id::UINT8: + case cudf::type_id::INT16: + case cudf::type_id::UINT16: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT32: + case cudf::type_id::FLOAT64: { + auto data = rmm::device_buffer(cudf::size_of(dtype) * num_rows, stream, mr); + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + return std::make_unique( + dtype, num_rows, std::move(data), std::move(null_mask), num_rows); + } + case cudf::type_id::STRING: { + // Create empty strings column with all nulls + rmm::device_uvector pairs(num_rows, stream, mr); + thrust::fill(rmm::exec_policy(stream), + pairs.begin(), + pairs.end(), + cudf::strings::detail::string_index_pair{nullptr, 0}); + return cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr); + } + case cudf::type_id::LIST: { + // Create LIST with all nulls + // Offsets: all zeros + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + + // Empty child + auto child_col = std::make_unique( + cudf::data_type{cudf::type_id::INT8}, + 0, + rmm::device_buffer{}, + rmm::device_buffer{}, + 0); + + // All null mask + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + num_rows, + std::move(null_mask), + stream, + mr); + } 
+ default: CUDF_FAIL("Unsupported type for null column creation"); + } +} + } // namespace namespace spark_rapids_jni { std::unique_ptr decode_protobuf_to_struct( cudf::column_view const& binary_input, + int total_num_fields, + std::vector const& decoded_field_indices, std::vector const& field_numbers, - std::vector const& out_types, + std::vector const& all_types, std::vector const& encodings, bool fail_on_errors) { @@ -399,13 +547,39 @@ std::unique_ptr decode_protobuf_to_struct( auto const child_type = in_list.child().type().id(); CUDF_EXPECTS(child_type == cudf::type_id::INT8 || child_type == cudf::type_id::UINT8, "binary_input must be a LIST column"); - CUDF_EXPECTS(field_numbers.size() == out_types.size(), - "field_numbers and out_types must have the same length"); - CUDF_EXPECTS(encodings.size() == out_types.size(), - "encodings and out_types must have the same length"); + CUDF_EXPECTS(static_cast(all_types.size()) == total_num_fields, + "all_types size must equal total_num_fields"); + CUDF_EXPECTS(decoded_field_indices.size() == field_numbers.size(), + "decoded_field_indices and field_numbers must have the same length"); + CUDF_EXPECTS(encodings.size() == field_numbers.size(), + "encodings and field_numbers must have the same length"); auto const stream = cudf::get_default_stream(); auto mr = cudf::get_current_device_resource_ref(); + auto rows = binary_input.size(); + auto num_decoded_fields = static_cast(field_numbers.size()); + + // Handle zero-row case + if (rows == 0) { + std::vector> empty_children; + empty_children.reserve(total_num_fields); + for (auto const& dt : all_types) { + empty_children.push_back(cudf::make_empty_column(dt)); + } + return cudf::make_structs_column( + 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); + } + + // Handle case with no fields to decode + if (num_decoded_fields == 0) { + std::vector> null_children; + null_children.reserve(total_num_fields); + for (auto const& dt : all_types) { + 
null_children.push_back(make_null_column(dt, rows, stream, mr)); + } + return cudf::make_structs_column( + rows, std::move(null_children), 0, rmm::device_buffer{}, stream, mr); + } auto rows = binary_input.size(); @@ -422,203 +596,362 @@ std::unique_ptr decode_protobuf_to_struct( auto d_in = cudf::column_device_view::create(binary_input, stream); - // Track parse errors across kernels. + // Prepare field descriptors for the scanning kernel + std::vector h_field_descs(num_decoded_fields); + for (int i = 0; i < num_decoded_fields; i++) { + int schema_idx = decoded_field_indices[i]; + h_field_descs[i].field_number = field_numbers[i]; + h_field_descs[i].expected_wire_type = + get_expected_wire_type(all_types[schema_idx].id(), encodings[i]); + } + + rmm::device_uvector d_field_descs(num_decoded_fields, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_descs.data(), + h_field_descs.data(), + num_decoded_fields * sizeof(field_descriptor), + cudaMemcpyHostToDevice, + stream.value())); + + // Allocate field locations array: [rows * num_decoded_fields] + rmm::device_uvector d_locations( + static_cast(rows) * num_decoded_fields, stream, mr); + + // Track errors rmm::device_uvector d_error(1, stream, mr); CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); - std::vector> children; - children.reserve(out_types.size()); - auto const threads = 256; auto const blocks = static_cast((rows + threads - 1) / threads); - for (std::size_t i = 0; i < out_types.size(); ++i) { - auto const fn = field_numbers[i]; - auto const dt = out_types[i]; - auto const enc = encodings[i]; - switch (dt.id()) { - case cudf::type_id::BOOL8: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - if (enc == ENC_DEFAULT) { - extract_varint_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else { - CUDF_FAIL("Unsupported encoding for BOOL8 protobuf field"); - } - auto [mask, null_count] = 
make_null_mask_from_valid(valid, stream, mr); - children.push_back( - std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::INT32: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - if (enc == ENC_ZIGZAG) { - extract_varint_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else if (enc == ENC_FIXED) { - extract_fixed_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else if (enc == ENC_DEFAULT) { - extract_varint_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else { - CUDF_FAIL("Unsupported encoding for INT32 protobuf field"); + // ========================================================================= + // Pass 1: Scan all messages and record field locations + // ========================================================================= + scan_all_fields_kernel<<>>( + *d_in, d_field_descs.data(), num_decoded_fields, d_locations.data(), d_error.data()); + + // Get message data pointer and offsets for pass 2 + auto const* message_data = + reinterpret_cast(in_list.child().data()); + auto const* list_offsets = in_list.offsets().data(); + // Get the base offset by copying from device to host + cudf::size_type base_offset = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&base_offset, + list_offsets, + sizeof(cudf::size_type), + cudaMemcpyDeviceToHost, + stream.value())); + stream.synchronize(); + + // ========================================================================= + // Pass 2: Extract data for each field + // ========================================================================= + std::vector> all_children(total_num_fields); + int decoded_idx = 0; + + for (int schema_idx = 0; schema_idx < total_num_fields; schema_idx++) { + if (decoded_idx < num_decoded_fields && + decoded_field_indices[decoded_idx] == schema_idx) { + // This field needs to be decoded + auto const dt = all_types[schema_idx]; + auto 
const enc = encodings[decoded_idx]; + + switch (dt.id()) { + case cudf::type_id::BOOL8: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + extract_varint_from_locations_kernel<<>>( + message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - children.push_back( - std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::UINT32: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - if (enc == ENC_FIXED) { - extract_fixed_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else if (enc == ENC_DEFAULT) { - extract_varint_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else { - CUDF_FAIL("Unsupported encoding for UINT32 protobuf field"); + + case cudf::type_id::INT32: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + if (enc == spark_rapids_jni::ENC_ZIGZAG) { + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, + num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + } else if (enc == spark_rapids_jni::ENC_FIXED) { + extract_fixed_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, + num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + } else { + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, + num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + } + 
auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - children.push_back( - std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::INT64: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - if (enc == ENC_ZIGZAG) { - extract_varint_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else if (enc == ENC_FIXED) { - extract_fixed_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else if (enc == ENC_DEFAULT) { - extract_varint_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else { - CUDF_FAIL("Unsupported encoding for INT64 protobuf field"); + + case cudf::type_id::UINT32: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + if (enc == spark_rapids_jni::ENC_FIXED) { + extract_fixed_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, + num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + } else { + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, + num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - children.push_back( - std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::UINT64: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - if (enc == 
ENC_FIXED) { - extract_fixed_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else if (enc == ENC_DEFAULT) { - extract_varint_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else { - CUDF_FAIL("Unsupported encoding for UINT64 protobuf field"); + + case cudf::type_id::INT64: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + if (enc == spark_rapids_jni::ENC_ZIGZAG) { + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, + num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + } else if (enc == spark_rapids_jni::ENC_FIXED) { + extract_fixed_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, + num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + } else { + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, + num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - children.push_back( - std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::FLOAT32: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - if (enc == ENC_DEFAULT) { - extract_fixed_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else { - CUDF_FAIL("Unsupported encoding for FLOAT32 protobuf field"); + + case cudf::type_id::UINT64: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + if (enc == spark_rapids_jni::ENC_FIXED) { + extract_fixed_from_locations_kernel<<>>( + 
message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, + num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + } else { + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, + num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - children.push_back( - std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::FLOAT64: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - if (enc == ENC_DEFAULT) { - extract_fixed_kernel<<>>( - *d_in, fn, out.data(), valid.data(), d_error.data()); - } else { - CUDF_FAIL("Unsupported encoding for FLOAT64 protobuf field"); + + case cudf::type_id::FLOAT32: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + extract_fixed_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, + num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - children.push_back( - std::make_unique(dt, rows, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::STRING: { - rmm::device_uvector pairs(rows, stream, mr); - if (enc == ENC_DEFAULT) { - extract_string_kernel<<>>( - *d_in, fn, pairs.data(), d_error.data()); - } else { - CUDF_FAIL("Unsupported encoding for STRING protobuf field"); + + case 
cudf::type_id::FLOAT64: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + extract_fixed_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, + num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; } - children.push_back( - cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr)); - break; - } - case cudf::type_id::LIST: { - // For protobuf `bytes` fields: we reuse the string extraction kernel to get the - // length-delimited raw bytes. The resulting strings column is then re-interpreted as - // LIST by extracting its internal offsets and char data (which is just raw bytes). - rmm::device_uvector pairs(rows, stream, mr); - if (enc == ENC_DEFAULT) { - extract_string_kernel<<>>( - *d_in, fn, pairs.data(), d_error.data()); - } else { - CUDF_FAIL("Unsupported encoding for LIST protobuf field"); + + case cudf::type_id::STRING: { + // Extract lengths and compute output offsets via prefix sum + rmm::device_uvector lengths(rows, stream, mr); + extract_lengths_kernel<<>>( + d_locations.data(), decoded_idx, num_decoded_fields, lengths.data(), rows); + + rmm::device_uvector output_offsets(rows + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); + + // Get total size + int32_t total_chars = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, + output_offsets.data() + rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + lengths.data() + rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + stream.synchronize(); + total_chars += last_len; + + // Set the final offset + 
CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + rows, + &total_chars, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + // Allocate and copy character data + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + copy_varlen_data_kernel<<>>( + message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + output_offsets.data(), + chars.data(), + rows); + } + + // Create validity mask (field found = valid) + rmm::device_uvector valid(rows, stream, mr); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(rows), + valid.begin(), + [locs = d_locations.data(), decoded_idx, num_decoded_fields] __device__(auto row) { + return locs[row * num_decoded_fields + decoded_idx].offset >= 0; + }); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + + // Create offsets column + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, + rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); + + // Create strings column using offsets + chars buffer + all_children[schema_idx] = cudf::make_strings_column( + rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); + break; } - auto strings = - cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr); - - // Use strings_column_view to get the underlying data - cudf::strings_column_view scv(*strings); - auto const null_count = strings->null_count(); - - // Get offsets - need to copy since we can't take ownership from a view - auto offsets_col = std::make_unique(scv.offsets(), stream, mr); - - // Get chars data as INT8 column - auto chars_data = scv.chars_begin(stream); - auto chars_size = static_cast(scv.chars_size(stream)); - rmm::device_uvector chars_vec(chars_size, stream, mr); - if (chars_size > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync( - chars_vec.data(), chars_data, chars_size, 
cudaMemcpyDeviceToDevice, stream.value())); + + case cudf::type_id::LIST: { + // For protobuf bytes: create LIST directly (optimization #2) + // Extract lengths and compute output offsets via prefix sum + rmm::device_uvector lengths(rows, stream, mr); + extract_lengths_kernel<<>>( + d_locations.data(), decoded_idx, num_decoded_fields, lengths.data(), rows); + + rmm::device_uvector output_offsets(rows + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); + + // Get total size + int32_t total_bytes = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, + output_offsets.data() + rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + lengths.data() + rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + stream.synchronize(); + total_bytes += last_len; + + // Set the final offset + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + rows, + &total_bytes, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + // Allocate and copy byte data directly to INT8 buffer + rmm::device_uvector child_data(total_bytes, stream, mr); + if (total_bytes > 0) { + copy_varlen_data_kernel<<>>( + message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + output_offsets.data(), + reinterpret_cast(child_data.data()), + rows); + } + + // Create validity mask + rmm::device_uvector valid(rows, stream, mr); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(rows), + valid.begin(), + [locs = d_locations.data(), decoded_idx, num_decoded_fields] __device__(auto row) { + return locs[row * num_decoded_fields + decoded_idx].offset >= 0; + }); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + + // Create offsets column + auto offsets_col = std::make_unique( + 
cudf::data_type{cudf::type_id::INT32}, + rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); + + // Create INT8 child column directly (no intermediate strings column!) + auto child_col = std::make_unique( + cudf::data_type{cudf::type_id::INT8}, + total_bytes, + child_data.release(), + rmm::device_buffer{}, + 0); + + all_children[schema_idx] = cudf::make_lists_column(rows, + std::move(offsets_col), + std::move(child_col), + null_count, + std::move(mask), + stream, + mr); + break; } - // Create INT8 column from chars data - auto child_col = - std::make_unique(cudf::data_type{cudf::type_id::INT8}, - chars_size, - chars_vec.release(), - rmm::device_buffer{}, // no null mask for chars - 0); // no nulls - - // Get null mask - auto null_mask = cudf::copy_bitmask(*strings, stream, mr); - - children.push_back(cudf::make_lists_column(rows, - std::move(offsets_col), - std::move(child_col), - null_count, - std::move(null_mask), - stream, - mr)); - break; + + default: CUDF_FAIL("Unsupported output type for protobuf decoder"); } - default: CUDF_FAIL("Unsupported output type for protobuf decoder"); + + decoded_idx++; + } else { + // This field is not decoded - create null column + all_children[schema_idx] = make_null_column(all_types[schema_idx], rows, stream, mr); } } - // Check for kernel launch errors + // Check for errors CUDF_CUDA_TRY(cudaPeekAtLastError()); // Check for any parse errors. @@ -633,17 +966,10 @@ std::unique_ptr decode_protobuf_to_struct( CUDF_EXPECTS(h_error == 0, "Malformed protobuf message or unsupported wire type"); } - // Note: We intentionally do NOT propagate input nulls to the output STRUCT validity. - // The expected semantics for this low-level helper (see ProtobufTest) are: - // - The STRUCT row is always valid (non-null) - // - Individual children are null if the input message is null or the field is missing - // - // Higher-level Spark expressions can still apply their own null semantics if desired. 
+ // Build the final struct rmm::device_buffer struct_mask{0, stream, mr}; - auto const struct_null_count = 0; - return cudf::make_structs_column( - rows, std::move(children), struct_null_count, std::move(struct_mask), stream, mr); + rows, std::move(all_children), 0, std::move(struct_mask), stream, mr); } } // namespace spark_rapids_jni diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index 15c27e7f0a..96c6ae5f8b 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -25,11 +25,21 @@ namespace spark_rapids_jni { +// Encoding constants +constexpr int ENC_DEFAULT = 0; +constexpr int ENC_FIXED = 1; +constexpr int ENC_ZIGZAG = 2; + /** * Decode protobuf messages (one message per row) from a LIST column into a STRUCT * column. * - * This is intentionally limited to top-level scalar fields. + * This uses a two-pass approach for efficiency: + * - Pass 1: Scan all messages once, recording (offset, length) for each requested field + * - Pass 2: Extract data in parallel using the recorded locations + * + * This is significantly faster than the per-field approach when decoding multiple fields, + * as each message is only parsed once regardless of the number of fields. * * Supported output child types (cudf dtypes) and corresponding protobuf field types: * - BOOL8 : protobuf `bool` (varint wire type) @@ -50,20 +60,24 @@ namespace spark_rapids_jni { * Nested messages, repeated fields, map fields, and oneof fields are out of scope for this API. * * @param binary_input LIST column, each row is one protobuf message - * @param field_numbers protobuf field numbers (one per output child) - * @param out_types output cudf data types (one per output child) - * @param encodings encoding type for each field (0=default, 1=fixed, 2=zigzag) - * @param fail_on_errors whether to throw on malformed messages. Note: error checking is performed - * after all kernels complete (not between kernel launches) to avoid synchronization - * overhead. 
If an error is detected, all kernels will have executed but an exception will be - * thrown. - * @return STRUCT column with the given children types; the STRUCT itself is always non-null, - * and individual child fields may be null when input message is null or field is missing + * @param total_num_fields Total number of fields in the output struct (including null columns) + * @param decoded_field_indices Indices into the output struct for fields that should be decoded. + * Fields not in this list will be null columns in the output. + * @param field_numbers Protobuf field numbers for decoded fields (parallel to decoded_field_indices) + * @param all_types Output cudf data types for ALL fields in the struct (size = total_num_fields) + * @param encodings Encoding type for each decoded field (0=default, 1=fixed, 2=zigzag) + * (parallel to decoded_field_indices) + * @param fail_on_errors Whether to throw on malformed messages. Note: error checking is performed + * after all kernels complete (not between kernel launches) to avoid synchronization overhead. + * @return STRUCT column with total_num_fields children. Decoded fields contain the parsed data, + * other fields contain all nulls. The STRUCT itself is always non-null. */ std::unique_ptr decode_protobuf_to_struct( cudf::column_view const& binary_input, + int total_num_fields, + std::vector const& decoded_field_indices, std::vector const& field_numbers, - std::vector const& out_types, + std::vector const& all_types, std::vector const& encodings, bool fail_on_errors); diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index a091bdbcbd..4b419aa60c 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -23,9 +23,16 @@ /** * GPU protobuf decoding utilities. 
 *
- * This API is intentionally limited to top-level scalar fields whose
- * values can be represented by a single cuDF scalar type. Supported protobuf field types
- * include scalar fields using the standard protobuf wire encodings:
+ * This API uses a two-pass approach for efficient decoding:
+ * <ul>
+ *   <li>Pass 1: Scan all messages once, recording (offset, length) for each requested field</li>
+ *   <li>Pass 2: Extract data in parallel using the recorded locations</li>
+ * </ul>
+ *
+ * This is significantly faster than the per-field parsing when decoding multiple fields,
+ * as each message is only parsed once regardless of the number of fields.
+ *
+ * Supported protobuf field types include scalar fields using the standard wire encodings:
 * <ul>
 *   <li>VARINT: {@code int32}, {@code int64}, {@code uint32}, {@code uint64}, {@code bool}</li>
 *   <li>ZIGZAG VARINT (encoding=2): {@code sint32}, {@code sint64}</li>
@@ -33,9 +40,7 @@
 *   <li>FIXED64 (encoding=1): {@code fixed64}, {@code sfixed64}, {@code double}</li>
 *   <li>LENGTH_DELIMITED: {@code string}, {@code bytes}</li>
 * </ul>
- * Each decoded field becomes a child column of the resulting STRUCT, with its cuDF type
- * specified via the corresponding {@code typeIds} entry.
- * <p>
    + * * Nested messages, repeated fields, map fields, and oneof fields are out of scope for this API. */ public class Protobuf { @@ -48,77 +53,127 @@ public class Protobuf { public static final int ENC_ZIGZAG = 2; /** - * Decode a protobuf message-per-row binary column into a single STRUCT column. + * Decode a protobuf message-per-row binary column into a STRUCT column. + * + * This method supports schema projection: only the fields specified in + * {@code decodedFieldIndices} will be decoded. Other fields in the output + * struct will contain all null values. * * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. - * @param fieldNumbers protobuf field numbers to decode (one per struct child) - * @param typeIds cudf native type ids (one per struct child) - * @param typeScales encoding info or decimal scales: - * - For non-decimal types, this is the encoding: 0=default, 1=fixed, 2=zigzag. - * - For decimal types, this is the scale (currently unsupported). - * @return a cudf STRUCT column where children correspond 1:1 with {@code fieldNumbers}/{@code typeIds}. + * @param totalNumFields Total number of fields in the output struct (including null columns). + * @param decodedFieldIndices Indices into the output struct for fields that should be decoded. + * These must be sorted in ascending order. + * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices). + * @param allTypeIds cudf native type ids for ALL fields in the output struct (size = totalNumFields). + * @param encodings Encoding info for decoded fields (parallel to decodedFieldIndices): + * 0=default (varint), 1=fixed, 2=zigzag. + * @return a cudf STRUCT column with totalNumFields children. Decoded fields contain parsed data, + * other fields contain all nulls. 
*/ public static ColumnVector decodeToStruct(ColumnView binaryInput, + int totalNumFields, + int[] decodedFieldIndices, int[] fieldNumbers, - int[] typeIds, - int[] typeScales) { - return decodeToStruct(binaryInput, fieldNumbers, typeIds, typeScales, true); + int[] allTypeIds, + int[] encodings) { + return decodeToStruct(binaryInput, totalNumFields, decodedFieldIndices, fieldNumbers, + allTypeIds, encodings, true); } /** - * Decode a protobuf message-per-row binary column into a single STRUCT column. + * Decode a protobuf message-per-row binary column into a STRUCT column. + * + * This method supports schema projection: only the fields specified in + * {@code decodedFieldIndices} will be decoded. Other fields in the output + * struct will contain all null values. * * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. - * @param fieldNumbers protobuf field numbers to decode (one per struct child) - * @param typeIds cudf native type ids (one per struct child) - * @param typeScales encoding info or decimal scales: - * - For non-decimal types, this is the encoding: 0=default, 1=fixed, 2=zigzag. - * - For decimal types, this is the scale (currently unsupported). + * @param totalNumFields Total number of fields in the output struct (including null columns). + * @param decodedFieldIndices Indices into the output struct for fields that should be decoded. + * These must be sorted in ascending order. + * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices). + * @param allTypeIds cudf native type ids for ALL fields in the output struct (size = totalNumFields). + * @param encodings Encoding info for decoded fields (parallel to decodedFieldIndices): + * 0=default (varint), 1=fixed, 2=zigzag. * @param failOnErrors if true, throw an exception on malformed protobuf messages. * If false, return nulls for fields that cannot be parsed. 
* Note: error checking is performed after all fields are processed, * not between fields, to avoid synchronization overhead. - * @return a cudf STRUCT column where children correspond 1:1 with {@code fieldNumbers}/{@code typeIds}. + * @return a cudf STRUCT column with totalNumFields children. Decoded fields contain parsed data, + * other fields contain all nulls. */ public static ColumnVector decodeToStruct(ColumnView binaryInput, + int totalNumFields, + int[] decodedFieldIndices, int[] fieldNumbers, - int[] typeIds, - int[] typeScales, + int[] allTypeIds, + int[] encodings, boolean failOnErrors) { - if (fieldNumbers == null || typeIds == null || typeScales == null) { - throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be non-null"); + // Parameter validation + if (decodedFieldIndices == null || fieldNumbers == null || + allTypeIds == null || encodings == null) { + throw new IllegalArgumentException("Arrays must be non-null"); + } + if (totalNumFields < 0) { + throw new IllegalArgumentException("totalNumFields must be non-negative"); + } + if (allTypeIds.length != totalNumFields) { + throw new IllegalArgumentException( + "allTypeIds length (" + allTypeIds.length + ") must equal totalNumFields (" + + totalNumFields + ")"); } - if (fieldNumbers.length == 0) { - throw new IllegalArgumentException("fieldNumbers must not be empty"); + if (decodedFieldIndices.length != fieldNumbers.length || + decodedFieldIndices.length != encodings.length) { + throw new IllegalArgumentException( + "decodedFieldIndices/fieldNumbers/encodings must be the same length"); } - if (fieldNumbers.length != typeIds.length || fieldNumbers.length != typeScales.length) { - throw new IllegalArgumentException("fieldNumbers/typeIds/typeScales must be the same length"); + + // Validate decoded field indices are in bounds and sorted + int prevIdx = -1; + for (int i = 0; i < decodedFieldIndices.length; i++) { + int idx = decodedFieldIndices[i]; + if (idx < 0 || idx >= 
totalNumFields) { + throw new IllegalArgumentException( + "Invalid decoded field index at position " + i + ": " + idx + + " (must be in range [0, " + totalNumFields + "))"); + } + if (idx <= prevIdx) { + throw new IllegalArgumentException( + "decodedFieldIndices must be sorted in ascending order without duplicates"); + } + prevIdx = idx; } - // Validate field numbers are positive (protobuf field numbers must be 1-536870911) + + // Validate field numbers are positive for (int i = 0; i < fieldNumbers.length; i++) { if (fieldNumbers[i] <= 0) { throw new IllegalArgumentException( - "Invalid field number at index " + i + ": " + fieldNumbers[i] - + " (field numbers must be positive)"); + "Invalid field number at index " + i + ": " + fieldNumbers[i] + + " (field numbers must be positive)"); } } - // Validate encoding values are within valid range - for (int i = 0; i < typeScales.length; i++) { - int enc = typeScales[i]; + + // Validate encoding values + for (int i = 0; i < encodings.length; i++) { + int enc = encodings[i]; if (enc < ENC_DEFAULT || enc > ENC_ZIGZAG) { throw new IllegalArgumentException( - "Invalid encoding value at index " + i + ": " + enc - + " (expected " + ENC_DEFAULT + ", " + ENC_FIXED + ", or " + ENC_ZIGZAG + ")"); + "Invalid encoding value at index " + i + ": " + enc + + " (expected " + ENC_DEFAULT + ", " + ENC_FIXED + ", or " + ENC_ZIGZAG + ")"); } } - long handle = decodeToStruct(binaryInput.getNativeView(), fieldNumbers, typeIds, typeScales, failOnErrors); + + long handle = decodeToStruct(binaryInput.getNativeView(), totalNumFields, + decodedFieldIndices, fieldNumbers, allTypeIds, + encodings, failOnErrors); return new ColumnVector(handle); } private static native long decodeToStruct(long binaryInputView, + int totalNumFields, + int[] decodedFieldIndices, int[] fieldNumbers, - int[] typeIds, - int[] typeScales, + int[] allTypeIds, + int[] encodings, boolean failOnErrors); } - diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java 
b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index c2d1be0a35..b3ed38b59a 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025, NVIDIA CORPORATION. + * Copyright (c) 2025-2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ import ai.rapids.cudf.AssertUtils; import ai.rapids.cudf.ColumnVector; +import ai.rapids.cudf.ColumnView; import ai.rapids.cudf.DType; import ai.rapids.cudf.HostColumnVector.*; import ai.rapids.cudf.Table; @@ -141,6 +142,40 @@ private static Byte[] concat(Byte[]... parts) { return out; } + // ============================================================================ + // Helper methods for calling the new API + // ============================================================================ + + /** + * Helper method that wraps the new API for tests that decode all fields. + * This simulates the old API behavior where all fields are decoded. + */ + private static ColumnVector decodeAllFields(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] encodings) { + return decodeAllFields(binaryInput, fieldNumbers, typeIds, encodings, true); + } + + /** + * Helper method that wraps the new API for tests that decode all fields. + * This simulates the old API behavior where all fields are decoded. 
+ */ + private static ColumnVector decodeAllFields(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] encodings, + boolean failOnErrors) { + int numFields = fieldNumbers.length; + // When decoding all fields, decodedFieldIndices is [0, 1, 2, ..., n-1] + int[] decodedFieldIndices = new int[numFields]; + for (int i = 0; i < numFields; i++) { + decodedFieldIndices[i] = i; + } + return Protobuf.decodeToStruct(binaryInput, numFields, decodedFieldIndices, + fieldNumbers, typeIds, encodings, failOnErrors); + } + // ============================================================================ // Basic Type Tests // ============================================================================ @@ -168,7 +203,7 @@ void decodeVarintAndStringToStruct() { ColumnVector expectedId = ColumnVector.fromBoxedLongs(100L, 200L, null); ColumnVector expectedName = ColumnVector.fromStrings("alice", null, null); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedId, expectedName); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1, 2}, new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, @@ -198,7 +233,7 @@ void decodeMoreTypes() { ColumnVector expectedB = ColumnVector.fromLists( new ListType(true, new BasicType(true, DType.INT8)), Arrays.asList((byte) 1, (byte) 2, (byte) 3)); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1, 2, 3, 4}, new int[]{ @@ -237,7 +272,7 @@ void decodeFloatDoubleAndBool() { ColumnVector expectedFloat = ColumnVector.fromBoxedFloats(3.14f, -1.5f); ColumnVector expectedDouble = ColumnVector.fromBoxedDoubles(2.71828, 0.0); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedBool, expectedFloat, expectedDouble); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( 
input.getColumn(0), new int[]{1, 2, 3}, new int[]{ @@ -249,6 +284,64 @@ void decodeFloatDoubleAndBool() { } } + // ============================================================================ + // Schema Projection Tests (new API feature) + // ============================================================================ + + @Test + void testSchemaProjection() { + // message Msg { int64 f1 = 1; string f2 = 2; int32 f3 = 3; } + // Only decode f1 and f3, f2 should be null + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(100)), + box(tag(2, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes()), + box(tag(3, WT_VARINT)), box(encodeVarint(42))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row0}).build(); + // Expected: f1=100, f2=null (not decoded), f3=42 + ColumnVector expectedF1 = ColumnVector.fromBoxedLongs(100L); + ColumnVector expectedF2 = ColumnVector.fromStrings((String)null); + ColumnVector expectedF3 = ColumnVector.fromBoxedInts(42); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedF1, expectedF2, expectedF3); + // Decode only fields at indices 0 and 2 (skip index 1) + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + 3, // total fields + new int[]{0, 2}, // decode only indices 0 and 2 + new int[]{1, 3}, // field numbers for decoded fields + new int[]{DType.INT64.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId()}, // types for ALL fields + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, // encodings for decoded fields + true)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testSchemaProjectionDecodeNone() { + // Decode no fields - all should be null + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(100)), + box(tag(2, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes())); + + try (Table input = new Table.TestBuilder().column(new 
Byte[][]{row0}).build(); + ColumnVector expectedF1 = ColumnVector.fromBoxedLongs((Long)null); + ColumnVector expectedF2 = ColumnVector.fromStrings((String)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedF1, expectedF2); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + 2, // total fields + new int[]{}, // decode no fields + new int[]{}, // no field numbers + new int[]{DType.INT64.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, // types for ALL fields + new int[]{}, // no encodings + true)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + // ============================================================================ // Varint Boundary Tests // ============================================================================ @@ -263,7 +356,7 @@ void testVarintMaxUint64() { (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0x01}); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.UINT64.getTypeId().getNativeId()}, @@ -284,7 +377,7 @@ void testVarintZero() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); ColumnVector expectedInt = ColumnVector.fromBoxedLongs(0L); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -305,7 +398,7 @@ void testVarintOverEncodedZero() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); ColumnVector expectedInt = ColumnVector.fromBoxedLongs(0L); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = 
decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -324,7 +417,7 @@ void testVarint10thByteInvalid() { (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0x02}); // 0x02 has 2nd bit set try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -353,7 +446,7 @@ void testZigzagInt32Min() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); ColumnVector expectedInt = ColumnVector.fromBoxedInts(minInt32); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT32.getTypeId().getNativeId()}, @@ -374,7 +467,7 @@ void testZigzagInt32Max() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); ColumnVector expectedInt = ColumnVector.fromBoxedInts(maxInt32); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT32.getTypeId().getNativeId()}, @@ -394,7 +487,7 @@ void testZigzagInt64Min() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); ColumnVector expectedLong = ColumnVector.fromBoxedLongs(minInt64); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedLong); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -413,7 +506,7 @@ void testZigzagInt64Max() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); ColumnVector expectedLong = 
ColumnVector.fromBoxedLongs(maxInt64); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedLong); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -432,7 +525,7 @@ void testZigzagNegativeOne() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); ColumnVector expectedLong = ColumnVector.fromBoxedLongs(-1L); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedLong); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -452,7 +545,7 @@ void testMalformedVarint() { (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF}; try (Table input = new Table.TestBuilder().column(new Byte[][]{malformed}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -470,7 +563,7 @@ void testTruncatedVarint() { // Single byte with continuation bit set but no following byte Byte[] truncated = concat(box(tag(1, WT_VARINT)), new Byte[]{(byte)0x80}); try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -488,7 +581,7 @@ void testTruncatedLengthDelimited() { // String field with length=5 but no actual data Byte[] truncated = concat(box(tag(2, WT_LEN)), box(encodeVarint(5))); try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{2}, new 
int[]{DType.STRING.getTypeId().getNativeId()}, @@ -506,7 +599,7 @@ void testTruncatedFixed32() { // Fixed32 needs 4 bytes but only 3 provided Byte[] truncated = concat(box(tag(1, WT_32BIT)), new Byte[]{0x01, 0x02, 0x03}); try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT32.getTypeId().getNativeId()}, @@ -525,7 +618,7 @@ void testTruncatedFixed64() { Byte[] truncated = concat(box(tag(1, WT_64BIT)), new Byte[]{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}); try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -546,7 +639,7 @@ void testPartialLengthDelimitedData() { box(encodeVarint(10)), box("hello".getBytes())); // only 5 bytes try (Table input = new Table.TestBuilder().column(new Byte[][]{partial}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.STRING.getTypeId().getNativeId()}, @@ -570,7 +663,7 @@ void testWrongWireType() { box(tag(1, WT_32BIT)), // wire type 5 instead of 0 box(encodeFixed32(100))); try (Table input = new Table.TestBuilder().column(new Byte[][]{wrongType}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, // expects varint @@ -590,7 +683,7 @@ void testWrongWireTypeForString() { box(tag(1, WT_VARINT)), box(encodeVarint(12345))); try (Table input = new Table.TestBuilder().column(new Byte[][]{wrongType}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new 
int[]{DType.STRING.getTypeId().getNativeId()}, // expects LEN @@ -619,7 +712,7 @@ void testSkipUnknownVarintField() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); ColumnVector expectedInt = ColumnVector.fromBoxedLongs(42L); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -640,7 +733,7 @@ void testSkipUnknownFixed64Field() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); ColumnVector expectedInt = ColumnVector.fromBoxedLongs(42L); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -662,7 +755,7 @@ void testSkipUnknownLengthDelimitedField() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); ColumnVector expectedInt = ColumnVector.fromBoxedLongs(42L); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -683,7 +776,7 @@ void testSkipUnknownFixed32Field() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); ColumnVector expectedInt = ColumnVector.fromBoxedLongs(42L); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -707,7 +800,7 @@ void testLastOneWins() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); 
ColumnVector expectedInt = ColumnVector.fromBoxedLongs(300L); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -727,7 +820,7 @@ void testLastOneWinsForString() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); ColumnVector expectedStr = ColumnVector.fromStrings("last"); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedStr); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.STRING.getTypeId().getNativeId()}, @@ -747,7 +840,7 @@ void testFailOnErrorsTrue() { (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF}; try (Table input = new Table.TestBuilder().column(new Byte[][]{malformed}).build()) { assertThrows(ai.rapids.cudf.CudfException.class, () -> { - try (ColumnVector result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -763,7 +856,7 @@ void testFieldNumberZeroInvalid() { // Field number 0 is reserved and invalid Byte[] invalid = concat(box(tag(0, WT_VARINT)), box(encodeVarint(123))); try (Table input = new Table.TestBuilder().column(new Byte[][]{invalid}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -784,7 +877,7 @@ void testEmptyMessage() { ColumnVector expectedInt = ColumnVector.fromBoxedLongs((Long)null); ColumnVector expectedStr = ColumnVector.fromStrings((String)null); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt, expectedStr); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( 
input.getColumn(0), new int[]{1, 2}, new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, @@ -810,7 +903,7 @@ void testFloatSpecialValues() { Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY, Float.NaN, Float.MIN_VALUE, Float.MAX_VALUE); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedFloat); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.FLOAT32.getTypeId().getNativeId()}, @@ -832,7 +925,7 @@ void testDoubleSpecialValues() { Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, Double.NaN, Double.MIN_VALUE, Double.MAX_VALUE); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedDouble); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.FLOAT64.getTypeId().getNativeId()}, @@ -906,7 +999,7 @@ void testFailfastMalformedVarint() { (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF}; try (Table input = new Table.TestBuilder().column(new Byte[][]{malformed}).build()) { assertThrows(ai.rapids.cudf.CudfException.class, () -> { - try (ColumnVector result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -923,7 +1016,7 @@ void testFailfastTruncatedVarint() { Byte[] truncated = concat(box(tag(1, WT_VARINT)), new Byte[]{(byte)0x80}); try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build()) { assertThrows(ai.rapids.cudf.CudfException.class, () -> { - try (ColumnVector result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -940,7 +1033,7 @@ void testFailfastTruncatedString() { Byte[] truncated = concat(box(tag(2, WT_LEN)), box(encodeVarint(5))); try (Table input = new 
Table.TestBuilder().column(new Byte[][]{truncated}).build()) { assertThrows(ai.rapids.cudf.CudfException.class, () -> { - try (ColumnVector result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{2}, new int[]{DType.STRING.getTypeId().getNativeId()}, @@ -957,7 +1050,7 @@ void testFailfastTruncatedFixed32() { Byte[] truncated = concat(box(tag(1, WT_32BIT)), new Byte[]{0x01, 0x02, 0x03}); try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build()) { assertThrows(ai.rapids.cudf.CudfException.class, () -> { - try (ColumnVector result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT32.getTypeId().getNativeId()}, @@ -974,7 +1067,7 @@ void testFailfastTruncatedFixed64() { Byte[] truncated = concat(box(tag(1, WT_64BIT)), new Byte[]{0x01, 0x02, 0x03, 0x04, 0x05}); try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build()) { assertThrows(ai.rapids.cudf.CudfException.class, () -> { - try (ColumnVector result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -991,7 +1084,7 @@ void testFailfastWrongWireType() { Byte[] row = concat(box(tag(1, WT_LEN)), box(encodeVarint(3)), box("abc".getBytes())); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { assertThrows(ai.rapids.cudf.CudfException.class, () -> { - try (ColumnVector result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -1008,7 +1101,7 @@ void testFailfastFieldNumberZero() { Byte[] row = concat(box(tag(0, WT_VARINT)), box(encodeVarint(42))); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { assertThrows(ai.rapids.cudf.CudfException.class, () -> { - try (ColumnVector 
result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -1024,7 +1117,7 @@ void testFailfastValidDataDoesNotThrow() { // Valid protobuf should not throw even with failOnErrors = true Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(42))); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, @@ -1036,4 +1129,45 @@ void testFailfastValidDataDoesNotThrow() { } } } + + // ============================================================================ + // Performance Benchmark Tests (Multi-field) + // ============================================================================ + + @Test + void testMultiFieldPerformance() { + // Test with 6 fields to verify fused kernel efficiency + // message Msg { bool f1=1; int32 f2=2; int64 f3=3; float f4=4; double f5=5; string f6=6; } + Byte[] row = concat( + box(tag(1, WT_VARINT)), new Byte[]{0x01}, + box(tag(2, WT_VARINT)), box(encodeVarint(12345)), + box(tag(3, WT_VARINT)), box(encodeVarint(9876543210L)), + box(tag(4, WT_32BIT)), box(encodeFloat(3.14f)), + box(tag(5, WT_64BIT)), box(encodeDouble(2.71828)), + box(tag(6, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1, 2, 3, 4, 5, 6}, + new int[]{ + DType.BOOL8.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.INT64.getTypeId().getNativeId(), + DType.FLOAT32.getTypeId().getNativeId(), + DType.FLOAT64.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, + new int[]{0, 0, 0, 0, 0, 0})) { + try (ColumnVector expectedBool = 
ColumnVector.fromBoxedBooleans(true); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(12345); + ColumnVector expectedLong = ColumnVector.fromBoxedLongs(9876543210L); + ColumnVector expectedFloat = ColumnVector.fromBoxedFloats(3.14f); + ColumnVector expectedDouble = ColumnVector.fromBoxedDoubles(2.71828); + ColumnVector expectedString = ColumnVector.fromStrings("hello"); + ColumnVector expectedStruct = ColumnVector.makeStruct( + expectedBool, expectedInt, expectedLong, expectedFloat, expectedDouble, expectedString)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + } } From ecd3a38dfb0dfafc389cc4e9c415086b2fe3a251 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 20 Jan 2026 20:14:56 +0800 Subject: [PATCH 012/107] fix merge Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index e0d2a087b7..27ea2c43fc 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -581,19 +581,6 @@ std::unique_ptr decode_protobuf_to_struct( rows, std::move(null_children), 0, rmm::device_buffer{}, stream, mr); } - auto rows = binary_input.size(); - - // Handle zero-row case explicitly - return empty STRUCT with properly typed children - if (rows == 0) { - std::vector> empty_children; - empty_children.reserve(out_types.size()); - for (auto const& dt : out_types) { - empty_children.push_back(cudf::make_empty_column(dt)); - } - return cudf::make_structs_column( - 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); - } - auto d_in = cudf::column_device_view::create(binary_input, stream); // Prepare field descriptors for the scanning kernel From c86f78cdc4fa993f8d5a3ad68f58bfe4065f5178 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 27 Jan 2026 16:14:28 +0800 Subject: [PATCH 013/107] Support enum, required and default values Signed-off-by: Haoyang Li --- 
src/main/cpp/src/ProtobufJni.cpp | 86 +- src/main/cpp/src/protobuf.cu | 355 +++++++- src/main/cpp/src/protobuf.hpp | 22 +- .../com/nvidia/spark/rapids/jni/Protobuf.java | 172 +++- .../nvidia/spark/rapids/jni/ProtobufTest.java | 854 +++++++++++++++++- 5 files changed, 1440 insertions(+), 49 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index d2b20923ab..37f4b16775 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -32,6 +32,13 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, jintArray field_numbers, jintArray all_type_ids, jintArray encodings, + jbooleanArray is_required, + jbooleanArray has_default_value, + jlongArray default_ints, + jdoubleArray default_floats, + jbooleanArray default_bools, + jobjectArray default_strings, + jobjectArray enum_valid_values, jboolean fail_on_errors) { JNI_NULL_CHECK(env, binary_input_view, "binary_input_view is null", 0); @@ -39,6 +46,13 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, JNI_NULL_CHECK(env, field_numbers, "field_numbers is null", 0); JNI_NULL_CHECK(env, all_type_ids, "all_type_ids is null", 0); JNI_NULL_CHECK(env, encodings, "encodings is null", 0); + JNI_NULL_CHECK(env, is_required, "is_required is null", 0); + JNI_NULL_CHECK(env, has_default_value, "has_default_value is null", 0); + JNI_NULL_CHECK(env, default_ints, "default_ints is null", 0); + JNI_NULL_CHECK(env, default_floats, "default_floats is null", 0); + JNI_NULL_CHECK(env, default_bools, "default_bools is null", 0); + JNI_NULL_CHECK(env, default_strings, "default_strings is null", 0); + JNI_NULL_CHECK(env, enum_valid_values, "enum_valid_values is null", 0); JNI_TRY { @@ -49,13 +63,25 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, cudf::jni::native_jintArray n_field_numbers(env, field_numbers); cudf::jni::native_jintArray n_all_type_ids(env, all_type_ids); cudf::jni::native_jintArray n_encodings(env, 
encodings); + cudf::jni::native_jbooleanArray n_is_required(env, is_required); + cudf::jni::native_jbooleanArray n_has_default(env, has_default_value); + cudf::jni::native_jlongArray n_default_ints(env, default_ints); + cudf::jni::native_jdoubleArray n_default_floats(env, default_floats); + cudf::jni::native_jbooleanArray n_default_bools(env, default_bools); + + int num_decoded_fields = n_decoded_indices.size(); // Validate array sizes - if (n_decoded_indices.size() != n_field_numbers.size() || - n_decoded_indices.size() != n_encodings.size()) { + if (n_field_numbers.size() != num_decoded_fields || + n_encodings.size() != num_decoded_fields || + n_is_required.size() != num_decoded_fields || + n_has_default.size() != num_decoded_fields || + n_default_ints.size() != num_decoded_fields || + n_default_floats.size() != num_decoded_fields || + n_default_bools.size() != num_decoded_fields) { JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, - "decoded_field_indices/field_numbers/encodings must be the same length", + "All decoded field arrays must have the same length", 0); } if (n_all_type_ids.size() != total_num_fields) { @@ -68,6 +94,56 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, std::vector decoded_indices(n_decoded_indices.begin(), n_decoded_indices.end()); std::vector field_nums(n_field_numbers.begin(), n_field_numbers.end()); std::vector encs(n_encodings.begin(), n_encodings.end()); + + // Convert jboolean arrays to std::vector + std::vector required_flags; + std::vector has_default_flags; + std::vector default_bool_values; + required_flags.reserve(num_decoded_fields); + has_default_flags.reserve(num_decoded_fields); + default_bool_values.reserve(num_decoded_fields); + for (int i = 0; i < num_decoded_fields; ++i) { + required_flags.push_back(n_is_required[i] != 0); + has_default_flags.push_back(n_has_default[i] != 0); + default_bool_values.push_back(n_default_bools[i] != 0); + } + + // Convert default int/float values + 
std::vector default_int_values(n_default_ints.begin(), n_default_ints.end()); + std::vector default_float_values(n_default_floats.begin(), n_default_floats.end()); + + // Convert default string values (byte[][] -> vector>) + std::vector> default_string_values; + default_string_values.reserve(num_decoded_fields); + for (int i = 0; i < num_decoded_fields; ++i) { + jbyteArray byte_arr = static_cast(env->GetObjectArrayElement(default_strings, i)); + if (byte_arr == nullptr) { + default_string_values.emplace_back(); // empty vector for null + } else { + jsize len = env->GetArrayLength(byte_arr); + jbyte* bytes = env->GetByteArrayElements(byte_arr, nullptr); + default_string_values.emplace_back( + reinterpret_cast(bytes), + reinterpret_cast(bytes) + len); + env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT); + } + } + + // Convert enum valid values (int[][] -> vector>) + // Each element is either null (not an enum field) or an array of valid enum values + std::vector> enum_values; + enum_values.reserve(num_decoded_fields); + for (int i = 0; i < num_decoded_fields; ++i) { + jintArray int_arr = static_cast(env->GetObjectArrayElement(enum_valid_values, i)); + if (int_arr == nullptr) { + enum_values.emplace_back(); // empty vector for null (not an enum field) + } else { + jsize len = env->GetArrayLength(int_arr); + jint* ints = env->GetIntArrayElements(int_arr, nullptr); + enum_values.emplace_back(ints, ints + len); + env->ReleaseIntArrayElements(int_arr, ints, JNI_ABORT); + } + } // Build all_types vector - types for ALL fields in the output struct std::vector all_types; @@ -78,7 +154,9 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, } auto result = spark_rapids_jni::decode_protobuf_to_struct( - *input, total_num_fields, decoded_indices, field_nums, all_types, encs, fail_on_errors); + *input, total_num_fields, decoded_indices, field_nums, all_types, encs, + required_flags, has_default_flags, default_int_values, default_float_values, + 
default_bool_values, default_string_values, enum_values, fail_on_errors); return cudf::jni::release_as_jlong(result); } JNI_CATCH(env, 0); diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 27ea2c43fc..a886dde4cc 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -277,6 +277,7 @@ __global__ void scan_all_fields_kernel( /** * Extract varint field data using pre-recorded locations. + * Supports default values for missing fields. */ template __global__ void extract_varint_from_locations_kernel( @@ -289,14 +290,22 @@ __global__ void extract_varint_from_locations_kernel( OutT* out, bool* valid, int num_rows, - int* error_flag) + int* error_flag, + bool has_default = false, + int64_t default_value = 0) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; auto loc = locations[row * num_fields + field_idx]; if (loc.offset < 0) { - valid[row] = false; + // Field not found - use default value if available + if (has_default) { + out[row] = static_cast(default_value); + valid[row] = true; + } else { + valid[row] = false; + } return; } @@ -320,6 +329,7 @@ __global__ void extract_varint_from_locations_kernel( /** * Extract fixed-size field data (fixed32, fixed64, float, double). + * Supports default values for missing fields. 
*/ template __global__ void extract_fixed_from_locations_kernel( @@ -332,14 +342,22 @@ __global__ void extract_fixed_from_locations_kernel( OutT* out, bool* valid, int num_rows, - int* error_flag) + int* error_flag, + bool has_default = false, + OutT default_value = OutT{}) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; auto loc = locations[row * num_fields + field_idx]; if (loc.offset < 0) { - valid[row] = false; + // Field not found - use default value if available + if (has_default) { + out[row] = default_value; + valid[row] = true; + } else { + valid[row] = false; + } return; } @@ -372,6 +390,7 @@ __global__ void extract_fixed_from_locations_kernel( /** * Kernel to copy variable-length data (string/bytes) to output buffer. * Uses pre-computed output offsets from prefix sum. + * Supports default values for missing fields. */ __global__ void copy_varlen_data_kernel( uint8_t const* message_data, @@ -382,17 +401,31 @@ __global__ void copy_varlen_data_kernel( int num_fields, int32_t const* output_offsets, // Pre-computed output offsets (prefix sum) char* output_data, - int num_rows) + int num_rows, + bool has_default = false, + uint8_t const* default_data = nullptr, + int32_t default_length = 0) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; auto loc = locations[row * num_fields + field_idx]; - if (loc.offset < 0 || loc.length == 0) return; + char* dst = output_data + output_offsets[row]; + + if (loc.offset < 0) { + // Field not found - use default if available + if (has_default && default_length > 0) { + for (int i = 0; i < default_length; i++) { + dst[i] = static_cast(default_data[i]); + } + } + return; + } + + if (loc.length == 0) return; auto row_start = input_offsets[row] - base_offset; uint8_t const* src = message_data + row_start + loc.offset; - char* dst = output_data + output_offsets[row]; // Copy data for (int i = 0; i < loc.length; i++) { @@ -402,19 +435,28 @@ 
__global__ void copy_varlen_data_kernel( /** * Kernel to extract lengths from locations for prefix sum. + * Supports default values for missing fields. */ __global__ void extract_lengths_kernel( field_location const* locations, int field_idx, int num_fields, int32_t* lengths, - int num_rows) + int num_rows, + bool has_default = false, + int32_t default_length = 0) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; - auto loc = locations[row * num_fields + field_idx]; - lengths[row] = (loc.offset >= 0) ? loc.length : 0; + auto loc = locations[row * num_fields + field_idx]; + if (loc.offset >= 0) { + lengths[row] = loc.length; + } else if (has_default) { + lengths[row] = default_length; + } else { + lengths[row] = 0; + } } // ============================================================================ @@ -530,6 +572,85 @@ std::unique_ptr make_null_column( } // namespace +// ============================================================================ +// Kernel to check required fields after scan pass +// ============================================================================ + +/** + * Check if any required fields are missing (offset < 0) and set error flag. + * This is called after the scan pass to validate required field constraints. + */ +__global__ void check_required_fields_kernel( + field_location const* locations, // [num_rows * num_fields] + uint8_t const* is_required, // [num_fields] (1 = required, 0 = optional) + int num_fields, + int num_rows, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + for (int f = 0; f < num_fields; f++) { + if (is_required[f] != 0 && locations[row * num_fields + f].offset < 0) { + // Required field is missing - set error flag + atomicExch(error_flag, 1); + return; // No need to check other fields for this row + } + } +} + +/** + * Validate enum values against a set of valid values. 
+ * If a value is not in the valid set: + * 1. Mark the field as invalid (valid[row] = false) + * 2. Mark the row as having an invalid enum (row_has_invalid_enum[row] = true) + * + * This matches Spark CPU PERMISSIVE mode behavior: when an unknown enum value is + * encountered, the entire struct row is set to null (not just the enum field). + * + * The valid_values array must be sorted for binary search. + */ +__global__ void validate_enum_values_kernel( + int32_t const* values, // [num_rows] extracted enum values + bool* valid, // [num_rows] field validity flags (will be modified) + bool* row_has_invalid_enum, // [num_rows] row-level invalid enum flag (will be set to true) + int32_t const* valid_enum_values, // sorted array of valid enum values + int num_valid_values, // size of valid_enum_values + int num_rows) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + // Skip if already invalid (field was missing) - missing field is not an enum error + if (!valid[row]) return; + + int32_t val = values[row]; + + // Binary search for the value in valid_enum_values + int left = 0; + int right = num_valid_values - 1; + bool found = false; + + while (left <= right) { + int mid = left + (right - left) / 2; + if (valid_enum_values[mid] == val) { + found = true; + break; + } else if (valid_enum_values[mid] < val) { + left = mid + 1; + } else { + right = mid - 1; + } + } + + // If not found, mark as invalid + if (!found) { + valid[row] = false; + // Also mark the row as having an invalid enum - this will null the entire struct row + row_has_invalid_enum[row] = true; + } +} + namespace spark_rapids_jni { std::unique_ptr decode_protobuf_to_struct( @@ -539,6 +660,13 @@ std::unique_ptr decode_protobuf_to_struct( std::vector const& field_numbers, std::vector const& all_types, std::vector const& encodings, + std::vector const& is_required, + std::vector const& has_default_value, + std::vector const& default_ints, + std::vector const& 
default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, bool fail_on_errors) { CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, @@ -553,6 +681,18 @@ std::unique_ptr decode_protobuf_to_struct( "decoded_field_indices and field_numbers must have the same length"); CUDF_EXPECTS(encodings.size() == field_numbers.size(), "encodings and field_numbers must have the same length"); + CUDF_EXPECTS(is_required.size() == field_numbers.size(), + "is_required and field_numbers must have the same length"); + CUDF_EXPECTS(has_default_value.size() == field_numbers.size(), + "has_default_value and field_numbers must have the same length"); + CUDF_EXPECTS(default_ints.size() == field_numbers.size(), + "default_ints and field_numbers must have the same length"); + CUDF_EXPECTS(default_floats.size() == field_numbers.size(), + "default_floats and field_numbers must have the same length"); + CUDF_EXPECTS(default_bools.size() == field_numbers.size(), + "default_bools and field_numbers must have the same length"); + CUDF_EXPECTS(default_strings.size() == field_numbers.size(), + "default_strings and field_numbers must have the same length"); auto const stream = cudf::get_default_stream(); auto mr = cudf::get_current_device_resource_ref(); @@ -607,6 +747,18 @@ std::unique_ptr decode_protobuf_to_struct( rmm::device_uvector d_error(1, stream, mr); CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + // Check if any field has enum validation + bool has_enum_fields = std::any_of(enum_valid_values.begin(), enum_valid_values.end(), + [](auto const& v) { return !v.empty(); }); + + // Track rows with invalid enum values (used to null entire struct row) + // This matches Spark CPU PERMISSIVE mode behavior + rmm::device_uvector d_row_has_invalid_enum(has_enum_fields ? 
rows : 0, stream, mr); + if (has_enum_fields) { + // Initialize all to false (no invalid enums yet) + CUDF_CUDA_TRY(cudaMemsetAsync(d_row_has_invalid_enum.data(), 0, rows * sizeof(bool), stream.value())); + } + auto const threads = 256; auto const blocks = static_cast((rows + threads - 1) / threads); @@ -616,6 +768,30 @@ std::unique_ptr decode_protobuf_to_struct( scan_all_fields_kernel<<>>( *d_in, d_field_descs.data(), num_decoded_fields, d_locations.data(), d_error.data()); + // ========================================================================= + // Check required fields (after scan pass) + // ========================================================================= + // Only check if any field is required to avoid unnecessary kernel launch + bool has_required_fields = std::any_of(is_required.begin(), is_required.end(), + [](bool b) { return b; }); + if (has_required_fields) { + // Copy is_required flags to device + // Note: std::vector is special (bitfield), so we convert to uint8_t + rmm::device_uvector d_is_required(num_decoded_fields, stream, mr); + std::vector h_is_required_vec(num_decoded_fields); + for (int i = 0; i < num_decoded_fields; i++) { + h_is_required_vec[i] = is_required[i] ? 1 : 0; + } + CUDF_CUDA_TRY(cudaMemcpyAsync(d_is_required.data(), + h_is_required_vec.data(), + num_decoded_fields * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream.value())); + + check_required_fields_kernel<<>>( + d_locations.data(), d_is_required.data(), num_decoded_fields, rows, d_error.data()); + } + // Get message data pointer and offsets for pass 2 auto const* message_data = reinterpret_cast(in_list.child().data()); @@ -646,6 +822,8 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::BOOL8: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + int64_t def_val = has_def ? (default_bools[decoded_idx] ? 
1 : 0) : 0; extract_varint_from_locations_kernel<<>>( message_data, list_offsets, @@ -656,7 +834,9 @@ std::unique_ptr decode_protobuf_to_struct( out.data(), valid.data(), rows, - d_error.data()); + d_error.data(), + has_def, + def_val); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = std::make_unique(dt, rows, out.release(), std::move(mask), null_count); @@ -666,19 +846,44 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::INT32: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + int64_t def_int = has_def ? default_ints[decoded_idx] : 0; + int32_t def_fixed = static_cast(def_int); if (enc == spark_rapids_jni::ENC_ZIGZAG) { extract_varint_from_locations_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), + has_def, def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { extract_fixed_from_locations_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), + has_def, def_fixed); } else { extract_varint_from_locations_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), + has_def, def_int); } + + // Validate enum values if this is an enum field + // enum_valid_values[decoded_idx] is non-empty for enum fields + auto const& valid_enums = enum_valid_values[decoded_idx]; + if (!valid_enums.empty()) { + // Copy valid enum values to device + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + 
CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + // Validate enum values - unknown values will null the entire row + validate_enum_values_kernel<<>>( + out.data(), valid.data(), d_row_has_invalid_enum.data(), + d_valid_enums.data(), static_cast(valid_enums.size()), rows); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = std::make_unique(dt, rows, out.release(), std::move(mask), null_count); @@ -688,14 +893,19 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::UINT32: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + int64_t def_int = has_def ? default_ints[decoded_idx] : 0; + uint32_t def_fixed = static_cast(def_int); if (enc == spark_rapids_jni::ENC_FIXED) { extract_fixed_from_locations_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), + has_def, def_fixed); } else { extract_varint_from_locations_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), + has_def, def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = @@ -706,18 +916,23 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::INT64: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + int64_t def_int = has_def ? 
default_ints[decoded_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { extract_varint_from_locations_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), + has_def, def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { extract_fixed_from_locations_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), + has_def, def_int); } else { extract_varint_from_locations_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), + has_def, def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = @@ -728,14 +943,19 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::UINT64: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + int64_t def_int = has_def ? 
default_ints[decoded_idx] : 0; + uint64_t def_fixed = static_cast(def_int); if (enc == spark_rapids_jni::ENC_FIXED) { extract_fixed_from_locations_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), + has_def, def_fixed); } else { extract_varint_from_locations_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), + has_def, def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = @@ -746,9 +966,12 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::FLOAT32: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + float def_float = has_def ? static_cast(default_floats[decoded_idx]) : 0.0f; extract_fixed_from_locations_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), + has_def, def_float); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = std::make_unique(dt, rows, out.release(), std::move(mask), null_count); @@ -758,9 +981,12 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::FLOAT64: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + double def_double = has_def ? 
default_floats[decoded_idx] : 0.0; extract_fixed_from_locations_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data()); + num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), + has_def, def_double); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = std::make_unique(dt, rows, out.release(), std::move(mask), null_count); @@ -768,10 +994,26 @@ std::unique_ptr decode_protobuf_to_struct( } case cudf::type_id::STRING: { + // Check for default value + bool has_def = has_default_value[decoded_idx]; + auto const& def_str = default_strings[decoded_idx]; + int32_t def_len = has_def ? static_cast(def_str.size()) : 0; + + // Copy default string to device if needed + rmm::device_uvector d_default_str(def_len, stream, mr); + if (has_def && def_len > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), + def_str.data(), + def_len, + cudaMemcpyHostToDevice, + stream.value())); + } + // Extract lengths and compute output offsets via prefix sum rmm::device_uvector lengths(rows, stream, mr); extract_lengths_kernel<<>>( - d_locations.data(), decoded_idx, num_decoded_fields, lengths.data(), rows); + d_locations.data(), decoded_idx, num_decoded_fields, lengths.data(), rows, + has_def, def_len); rmm::device_uvector output_offsets(rows + 1, stream, mr); thrust::exclusive_scan( @@ -812,18 +1054,21 @@ std::unique_ptr decode_protobuf_to_struct( num_decoded_fields, output_offsets.data(), chars.data(), - rows); + rows, + has_def, + d_default_str.data(), + def_len); } - // Create validity mask (field found = valid) + // Create validity mask (field found OR has default = valid) rmm::device_uvector valid(rows, stream, mr); thrust::transform( rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(rows), valid.begin(), - [locs = d_locations.data(), decoded_idx, num_decoded_fields] 
__device__(auto row) { - return locs[row * num_decoded_fields + decoded_idx].offset >= 0; + [locs = d_locations.data(), decoded_idx, num_decoded_fields, has_def] __device__(auto row) { + return locs[row * num_decoded_fields + decoded_idx].offset >= 0 || has_def; }); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); @@ -843,10 +1088,26 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::LIST: { // For protobuf bytes: create LIST directly (optimization #2) + // Check for default value + bool has_def = has_default_value[decoded_idx]; + auto const& def_bytes = default_strings[decoded_idx]; + int32_t def_len = has_def ? static_cast(def_bytes.size()) : 0; + + // Copy default bytes to device if needed + rmm::device_uvector d_default_bytes(def_len, stream, mr); + if (has_def && def_len > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_bytes.data(), + def_bytes.data(), + def_len, + cudaMemcpyHostToDevice, + stream.value())); + } + // Extract lengths and compute output offsets via prefix sum rmm::device_uvector lengths(rows, stream, mr); extract_lengths_kernel<<>>( - d_locations.data(), decoded_idx, num_decoded_fields, lengths.data(), rows); + d_locations.data(), decoded_idx, num_decoded_fields, lengths.data(), rows, + has_def, def_len); rmm::device_uvector output_offsets(rows + 1, stream, mr); thrust::exclusive_scan( @@ -887,18 +1148,21 @@ std::unique_ptr decode_protobuf_to_struct( num_decoded_fields, output_offsets.data(), reinterpret_cast(child_data.data()), - rows); + rows, + has_def, + d_default_bytes.data(), + def_len); } - // Create validity mask + // Create validity mask (field found OR has default = valid) rmm::device_uvector valid(rows, stream, mr); thrust::transform( rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(rows), valid.begin(), - [locs = d_locations.data(), decoded_idx, num_decoded_fields] __device__(auto row) { - return locs[row * num_decoded_fields + decoded_idx].offset >= 0; 
+ [locs = d_locations.data(), decoded_idx, num_decoded_fields, has_def] __device__(auto row) { + return locs[row * num_decoded_fields + decoded_idx].offset >= 0 || has_def; }); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); @@ -941,7 +1205,7 @@ std::unique_ptr decode_protobuf_to_struct( // Check for errors CUDF_CUDA_TRY(cudaPeekAtLastError()); - // Check for any parse errors. + // Check for any parse errors or missing required fields. // Note: We check errors after all kernels complete rather than between kernel launches // to avoid expensive synchronization overhead. If fail_on_errors is true and an error // occurred, all kernels will have executed but we throw an exception here. @@ -950,13 +1214,32 @@ std::unique_ptr decode_protobuf_to_struct( cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); if (fail_on_errors) { - CUDF_EXPECTS(h_error == 0, "Malformed protobuf message or unsupported wire type"); + CUDF_EXPECTS(h_error == 0, + "Malformed protobuf message, unsupported wire type, or missing required field"); } // Build the final struct + // If any rows have invalid enum values, create a null mask for the struct + // This matches Spark CPU PERMISSIVE mode: unknown enum values null the entire row + cudf::size_type struct_null_count = 0; rmm::device_buffer struct_mask{0, stream, mr}; + + if (has_enum_fields) { + // Create struct null mask: row is valid if it has NO invalid enums + auto [mask, null_count] = cudf::detail::valid_if( + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(rows), + [row_invalid = d_row_has_invalid_enum.data()] __device__(cudf::size_type row) { + return !row_invalid[row]; // valid if NOT invalid + }, + stream, + mr); + struct_mask = std::move(mask); + struct_null_count = null_count; + } + return cudf::make_structs_column( - rows, std::move(all_children), 0, std::move(struct_mask), stream, mr); + rows, std::move(all_children), 
struct_null_count, std::move(struct_mask), stream, mr); } } // namespace spark_rapids_jni diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index 96c6ae5f8b..817374dc66 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -67,8 +67,19 @@ constexpr int ENC_ZIGZAG = 2; * @param all_types Output cudf data types for ALL fields in the struct (size = total_num_fields) * @param encodings Encoding type for each decoded field (0=default, 1=fixed, 2=zigzag) * (parallel to decoded_field_indices) - * @param fail_on_errors Whether to throw on malformed messages. Note: error checking is performed - * after all kernels complete (not between kernel launches) to avoid synchronization overhead. + * @param is_required Whether each decoded field is required (parallel to decoded_field_indices). + * If a required field is missing and fail_on_errors is true, an exception is thrown. + * @param has_default_value Whether each decoded field has a default value (parallel to decoded_field_indices) + * @param default_ints Default values for int/long/enum fields (parallel to decoded_field_indices) + * @param default_floats Default values for float/double fields (parallel to decoded_field_indices) + * @param default_bools Default values for bool fields (parallel to decoded_field_indices) + * @param default_strings Default values for string/bytes fields (parallel to decoded_field_indices) + * @param enum_valid_values Valid enum values for each field (parallel to decoded_field_indices). + * Empty vector means not an enum field. Non-empty vector contains the + * valid enum values. Unknown enum values will be set to null. + * @param fail_on_errors Whether to throw on malformed messages or missing required fields. + * Note: error checking is performed after all kernels complete (not between kernel launches) + * to avoid synchronization overhead. * @return STRUCT column with total_num_fields children. 
Decoded fields contain the parsed data, * other fields contain all nulls. The STRUCT itself is always non-null. */ @@ -79,6 +90,13 @@ std::unique_ptr decode_protobuf_to_struct( std::vector const& field_numbers, std::vector const& all_types, std::vector const& encodings, + std::vector const& is_required, + std::vector const& has_default_value, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, bool fail_on_errors); } // namespace spark_rapids_jni diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index 4b419aa60c..e88064be0a 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -77,7 +77,7 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, int[] allTypeIds, int[] encodings) { return decodeToStruct(binaryInput, totalNumFields, decodedFieldIndices, fieldNumbers, - allTypeIds, encodings, true); + allTypeIds, encodings, new boolean[decodedFieldIndices.length], true); } /** @@ -109,9 +109,152 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, int[] allTypeIds, int[] encodings, boolean failOnErrors) { + return decodeToStruct(binaryInput, totalNumFields, decodedFieldIndices, fieldNumbers, + allTypeIds, encodings, new boolean[decodedFieldIndices.length], failOnErrors); + } + + /** + * Decode a protobuf message-per-row binary column into a STRUCT column. + * + * This method supports schema projection: only the fields specified in + * {@code decodedFieldIndices} will be decoded. Other fields in the output + * struct will contain all null values. + * + * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. + * @param totalNumFields Total number of fields in the output struct (including null columns). 
+ * @param decodedFieldIndices Indices into the output struct for fields that should be decoded. + * These must be sorted in ascending order. + * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices). + * @param allTypeIds cudf native type ids for ALL fields in the output struct (size = totalNumFields). + * @param encodings Encoding info for decoded fields (parallel to decodedFieldIndices): + * 0=default (varint), 1=fixed, 2=zigzag. + * @param isRequired Whether each decoded field is required (parallel to decodedFieldIndices). + * If a required field is missing and failOnErrors is true, an exception is thrown. + * @param failOnErrors if true, throw an exception on malformed protobuf messages or missing required fields. + * If false, return nulls for fields that cannot be parsed or are missing. + * Note: error checking is performed after all fields are processed, + * not between fields, to avoid synchronization overhead. + * @return a cudf STRUCT column with totalNumFields children. Decoded fields contain parsed data, + * other fields contain all nulls. + */ + public static ColumnVector decodeToStruct(ColumnView binaryInput, + int totalNumFields, + int[] decodedFieldIndices, + int[] fieldNumbers, + int[] allTypeIds, + int[] encodings, + boolean[] isRequired, + boolean failOnErrors) { + int numFields = decodedFieldIndices.length; + return decodeToStruct(binaryInput, totalNumFields, decodedFieldIndices, fieldNumbers, + allTypeIds, encodings, isRequired, + new boolean[numFields], // hasDefaultValue - all false + new long[numFields], // defaultInts + new double[numFields], // defaultFloats + new boolean[numFields], // defaultBools + new byte[numFields][], // defaultStrings - all null + failOnErrors); + } + + /** + * Decode a protobuf message-per-row binary column into a STRUCT column with default values support. 
+ * + * This method supports schema projection: only the fields specified in + * {@code decodedFieldIndices} will be decoded. Other fields in the output + * struct will contain all null values. + * + * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. + * @param totalNumFields Total number of fields in the output struct (including null columns). + * @param decodedFieldIndices Indices into the output struct for fields that should be decoded. + * These must be sorted in ascending order. + * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices). + * @param allTypeIds cudf native type ids for ALL fields in the output struct (size = totalNumFields). + * @param encodings Encoding info for decoded fields (parallel to decodedFieldIndices): + * 0=default (varint), 1=fixed, 2=zigzag. + * @param isRequired Whether each decoded field is required (parallel to decodedFieldIndices). + * If a required field is missing and failOnErrors is true, an exception is thrown. + * @param hasDefaultValue Whether each decoded field has a default value (parallel to decodedFieldIndices). + * @param defaultInts Default values for int/long/enum fields (parallel to decodedFieldIndices). + * @param defaultFloats Default values for float/double fields (parallel to decodedFieldIndices). + * @param defaultBools Default values for bool fields (parallel to decodedFieldIndices). + * @param defaultStrings Default values for string/bytes fields as UTF-8 bytes (parallel to decodedFieldIndices). + * @param failOnErrors if true, throw an exception on malformed protobuf messages or missing required fields. + * If false, return nulls for fields that cannot be parsed or are missing. + * Note: error checking is performed after all fields are processed, + * not between fields, to avoid synchronization overhead. + * @return a cudf STRUCT column with totalNumFields children. 
Decoded fields contain parsed data, + * other fields contain all nulls. + */ + public static ColumnVector decodeToStruct(ColumnView binaryInput, + int totalNumFields, + int[] decodedFieldIndices, + int[] fieldNumbers, + int[] allTypeIds, + int[] encodings, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + boolean failOnErrors) { + return decodeToStruct(binaryInput, totalNumFields, decodedFieldIndices, fieldNumbers, + allTypeIds, encodings, isRequired, hasDefaultValue, + defaultInts, defaultFloats, defaultBools, defaultStrings, + new int[decodedFieldIndices.length][], failOnErrors); + } + + /** + * Decode a protobuf message-per-row binary column into a STRUCT column with default values + * and enum validation support. + * + * This method supports schema projection: only the fields specified in + * {@code decodedFieldIndices} will be decoded. Other fields in the output + * struct will contain all null values. + * + * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. + * @param totalNumFields Total number of fields in the output struct (including null columns). + * @param decodedFieldIndices Indices into the output struct for fields that should be decoded. + * These must be sorted in ascending order. + * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices). + * @param allTypeIds cudf native type ids for ALL fields in the output struct (size = totalNumFields). + * @param encodings Encoding info for decoded fields (parallel to decodedFieldIndices): + * 0=default (varint), 1=fixed, 2=zigzag. + * @param isRequired Whether each decoded field is required (parallel to decodedFieldIndices). + * If a required field is missing and failOnErrors is true, an exception is thrown. + * @param hasDefaultValue Whether each decoded field has a default value (parallel to decodedFieldIndices). 
+ * @param defaultInts Default values for int/long/enum fields (parallel to decodedFieldIndices). + * @param defaultFloats Default values for float/double fields (parallel to decodedFieldIndices). + * @param defaultBools Default values for bool fields (parallel to decodedFieldIndices). + * @param defaultStrings Default values for string/bytes fields as UTF-8 bytes (parallel to decodedFieldIndices). + * @param enumValidValues Valid enum values for each field (null if not an enum). Unknown enum + * values will be set to null to match Spark CPU PERMISSIVE mode behavior. + * @param failOnErrors if true, throw an exception on malformed protobuf messages or missing required fields. + * If false, return nulls for fields that cannot be parsed or are missing. + * Note: error checking is performed after all fields are processed, + * not between fields, to avoid synchronization overhead. + * @return a cudf STRUCT column with totalNumFields children. Decoded fields contain parsed data, + * other fields contain all nulls. 
+ */ + public static ColumnVector decodeToStruct(ColumnView binaryInput, + int totalNumFields, + int[] decodedFieldIndices, + int[] fieldNumbers, + int[] allTypeIds, + int[] encodings, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + int[][] enumValidValues, + boolean failOnErrors) { // Parameter validation if (decodedFieldIndices == null || fieldNumbers == null || - allTypeIds == null || encodings == null) { + allTypeIds == null || encodings == null || isRequired == null || + hasDefaultValue == null || defaultInts == null || defaultFloats == null || + defaultBools == null || defaultStrings == null || enumValidValues == null) { throw new IllegalArgumentException("Arrays must be non-null"); } if (totalNumFields < 0) { @@ -122,10 +265,18 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, "allTypeIds length (" + allTypeIds.length + ") must equal totalNumFields (" + totalNumFields + ")"); } - if (decodedFieldIndices.length != fieldNumbers.length || - decodedFieldIndices.length != encodings.length) { + int numDecodedFields = decodedFieldIndices.length; + if (fieldNumbers.length != numDecodedFields || + encodings.length != numDecodedFields || + isRequired.length != numDecodedFields || + hasDefaultValue.length != numDecodedFields || + defaultInts.length != numDecodedFields || + defaultFloats.length != numDecodedFields || + defaultBools.length != numDecodedFields || + defaultStrings.length != numDecodedFields || + enumValidValues.length != numDecodedFields) { throw new IllegalArgumentException( - "decodedFieldIndices/fieldNumbers/encodings must be the same length"); + "All decoded field arrays must have the same length as decodedFieldIndices"); } // Validate decoded field indices are in bounds and sorted @@ -165,7 +316,9 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, long handle = decodeToStruct(binaryInput.getNativeView(), 
totalNumFields, decodedFieldIndices, fieldNumbers, allTypeIds, - encodings, failOnErrors); + encodings, isRequired, hasDefaultValue, + defaultInts, defaultFloats, defaultBools, + defaultStrings, enumValidValues, failOnErrors); return new ColumnVector(handle); } @@ -175,5 +328,12 @@ private static native long decodeToStruct(long binaryInputView, int[] fieldNumbers, int[] allTypeIds, int[] encodings, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + int[][] enumValidValues, boolean failOnErrors); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index b3ed38b59a..f7450cda82 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -20,6 +20,7 @@ import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.ColumnView; import ai.rapids.cudf.DType; +import ai.rapids.cudf.HostColumnVector; import ai.rapids.cudf.HostColumnVector.*; import ai.rapids.cudf.Table; import org.junit.jupiter.api.Disabled; @@ -169,11 +170,30 @@ private static ColumnVector decodeAllFields(ColumnView binaryInput, int numFields = fieldNumbers.length; // When decoding all fields, decodedFieldIndices is [0, 1, 2, ..., n-1] int[] decodedFieldIndices = new int[numFields]; + boolean[] isRequired = new boolean[numFields]; // all false by default for (int i = 0; i < numFields; i++) { decodedFieldIndices[i] = i; } return Protobuf.decodeToStruct(binaryInput, numFields, decodedFieldIndices, - fieldNumbers, typeIds, encodings, failOnErrors); + fieldNumbers, typeIds, encodings, isRequired, failOnErrors); + } + + /** + * Helper method for tests with required field support. 
+ */ + private static ColumnVector decodeAllFieldsWithRequired(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] encodings, + boolean[] isRequired, + boolean failOnErrors) { + int numFields = fieldNumbers.length; + int[] decodedFieldIndices = new int[numFields]; + for (int i = 0; i < numFields; i++) { + decodedFieldIndices[i] = i; + } + return Protobuf.decodeToStruct(binaryInput, numFields, decodedFieldIndices, + fieldNumbers, typeIds, encodings, isRequired, failOnErrors); } // ============================================================================ @@ -934,6 +954,666 @@ void testDoubleSpecialValues() { } } + // ============================================================================ + // Enum Tests (enums.as.ints=true semantics) + // ============================================================================ + + @Test + void testEnumAsInt() { + // message Msg { enum Color { RED=0; GREEN=1; BLUE=2; } Color c = 1; } + // c = GREEN (value 1) - encoded as varint + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(1))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(1); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testEnumZeroValue() { + // Enum with value 0 (first/default enum value) + // c = RED (value 0) + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(0))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(0); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + 
input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testEnumUnknownValue() { + // Protobuf allows unknown enum values - they should still be decoded as integers + // c = 999 (unknown value not in enum definition) + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(999))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(999); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testEnumNegativeValue() { + // Negative enum values are valid in protobuf (stored as unsigned varint) + // c = -1 (represented as 0xFFFFFFFF in protobuf wire format) + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(-1L & 0xFFFFFFFFL))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(-1); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testEnumMultipleFields() { + // message Msg { enum Status { OK=0; ERROR=1; } Status s1 = 1; int32 count = 2; Status s2 = 3; } + // s1 = ERROR (1), count = 42, s2 = OK (0) + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), // s1 = ERROR + box(tag(2, WT_VARINT)), box(encodeVarint(42)), // count = 42 
+ box(tag(3, WT_VARINT)), box(encodeVarint(0))); // s2 = OK + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedS1 = ColumnVector.fromBoxedInts(1); + ColumnVector expectedCount = ColumnVector.fromBoxedInts(42); + ColumnVector expectedS2 = ColumnVector.fromBoxedInts(0); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedS1, expectedCount, expectedS2); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1, 2, 3}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testEnumMissingField() { + // Enum field not present in message - should be null + Byte[] row = concat(box(tag(2, WT_VARINT)), box(encodeVarint(42))); // only count field + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedEnum = ColumnVector.fromBoxedInts((Integer) null); + ColumnVector expectedCount = ColumnVector.fromBoxedInts(42); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedEnum, expectedCount); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Required Field Tests + // ============================================================================ + + @Test + void testRequiredFieldPresent() { + // message Msg { required int64 id = 1; optional string name = 2; } + // Both fields present - should decode successfully + Byte[] row = 
concat( + box(tag(1, WT_VARINT)), box(encodeVarint(42)), + box(tag(2, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedId = ColumnVector.fromBoxedLongs(42L); + ColumnVector expectedName = ColumnVector.fromStrings("hello"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedId, expectedName); + ColumnVector actualStruct = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, false}, // id is required, name is optional + true)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testRequiredFieldMissing_Permissive() { + // Required field missing in permissive mode - should return null without exception + // message Msg { required int64 id = 1; optional string name = 2; } + // Only name field present, required id is missing + Byte[] row = concat( + box(tag(2, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedId = ColumnVector.fromBoxedLongs((Long) null); + ColumnVector expectedName = ColumnVector.fromStrings("hello"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedId, expectedName); + ColumnVector actualStruct = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, false}, // id is required, name is optional + false)) { // permissive mode - don't fail on errors + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testRequiredFieldMissing_Failfast() { + // 
Required field missing in failfast mode - should throw exception + // message Msg { required int64 id = 1; optional string name = 2; } + // Only name field present, required id is missing + Byte[] row = concat( + box(tag(2, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, false}, // id is required, name is optional + true)) { // failfast mode - should throw + } + }); + } + } + + @Test + void testMultipleRequiredFields_AllPresent() { + // message Msg { required int32 a = 1; required int64 b = 2; required string c = 3; } + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(10)), + box(tag(2, WT_VARINT)), box(encodeVarint(20)), + box(tag(3, WT_LEN)), box(encodeVarint(3)), box("abc".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedA = ColumnVector.fromBoxedInts(10); + ColumnVector expectedB = ColumnVector.fromBoxedLongs(20L); + ColumnVector expectedC = ColumnVector.fromStrings("abc"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedA, expectedB, expectedC); + ColumnVector actualStruct = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1, 2, 3}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.INT64.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, true, true}, // all required + true)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void 
testMultipleRequiredFields_SomeMissing_Failfast() { + // message Msg { required int32 a = 1; required int64 b = 2; required string c = 3; } + // Only field a is present, b and c are missing + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(10))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1, 2, 3}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.INT64.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, true, true}, // all required + true)) { + } + }); + } + } + + @Test + void testOptionalFieldsOnly_NoValidation() { + // All fields optional - missing fields should not cause error + // message Msg { optional int32 a = 1; optional int64 b = 2; } + Byte[] row = new Byte[0]; // empty message + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedA = ColumnVector.fromBoxedInts((Integer) null); + ColumnVector expectedB = ColumnVector.fromBoxedLongs((Long) null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedA, expectedB); + ColumnVector actualStruct = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, // all optional + true)) { // even with failOnErrors=true, should succeed since all fields are optional + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testRequiredFieldWithMultipleRows() { + // Test required field validation across multiple rows + // Row 0: required field present + // Row 1: required field missing (should cause 
error in failfast mode) + Byte[] row0 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(42))); + Byte[] row1 = new Byte[0]; // empty - required field missing + + try (Table input = new Table.TestBuilder().column(row0, row1).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, // required + true)) { + } + }); + } + } + + // ============================================================================ + // Default Value Tests (API accepts parameters, CUDA fill not yet implemented) + // ============================================================================ + + /** + * Helper method for tests with default value support. + */ + private static ColumnVector decodeAllFieldsWithDefaults(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] encodings, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + boolean failOnErrors) { + int numFields = fieldNumbers.length; + int[] decodedFieldIndices = new int[numFields]; + for (int i = 0; i < numFields; i++) { + decodedFieldIndices[i] = i; + } + return Protobuf.decodeToStruct(binaryInput, numFields, decodedFieldIndices, + fieldNumbers, typeIds, encodings, isRequired, + hasDefaultValue, defaultInts, defaultFloats, + defaultBools, defaultStrings, failOnErrors); + } + + @Test + void testDefaultValueForMissingFields() { + // Test that missing fields with default values return the defaults + Byte[] row = new Byte[0]; // empty message + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + // With default values set, missing fields should return the default values + ColumnVector expectedA = ColumnVector.fromBoxedInts(42); + ColumnVector expectedB = 
ColumnVector.fromBoxedLongs(100L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedA, expectedB); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, // not required + new boolean[]{true, true}, // has default value + new long[]{42, 100}, // default int values (42, 100) + new double[]{0.0, 0.0}, // default float values (unused for int fields) + new boolean[]{false, false}, // default bool values (unused) + new byte[][]{null, null}, // default string values (unused) + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultValueFieldPresent_OverridesDefault() { + // When field is present, use the actual value (not the default) + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(99)), + box(tag(2, WT_VARINT)), box(encodeVarint(200))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedA = ColumnVector.fromBoxedInts(99); + ColumnVector expectedB = ColumnVector.fromBoxedLongs(200L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedA, expectedB); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, // not required + new boolean[]{true, true}, // has default value + new long[]{42, 100}, // default values - NOT used since field is present + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + false)) { + // Actual values should be used, not defaults + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + 
@Test + void testDefaultIntValue() { + // optional int32 count = 1 [default = 42]; + // Empty message should return the default value + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(42); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, // not required + new boolean[]{true}, // has default + new long[]{42}, // default = 42 + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultBoolValue() { + // optional bool flag = 1 [default = true]; + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedBool = ColumnVector.fromBoxedBooleans(true); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedBool); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new int[]{DType.BOOL8.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{true}, // default = true + new byte[][]{null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultFloatValue() { + // optional double rate = 1 [default = 3.14]; + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedDouble = ColumnVector.fromBoxedDoubles(3.14); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedDouble); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + 
input.getColumn(0), + new int[]{1}, + new int[]{DType.FLOAT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{3.14}, // default = 3.14 + new boolean[]{false}, + new byte[][]{null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultInt64Value() { + // optional int64 big_num = 1 [default = 9876543210]; + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedLong = ColumnVector.fromBoxedLongs(9876543210L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedLong); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{9876543210L}, // default = 9876543210 + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testMixedDefaultAndNonDefaultFields() { + // optional int32 a = 1 [default = 42]; + // optional int64 b = 2; (no default) + // optional bool c = 3 [default = true]; + // Empty message: a=42, b=null, c=true + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedA = ColumnVector.fromBoxedInts(42); + ColumnVector expectedB = ColumnVector.fromBoxedLongs((Long) null); // no default + ColumnVector expectedC = ColumnVector.fromBoxedBooleans(true); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedA, expectedB, expectedC); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1, 2, 3}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.INT64.getTypeId().getNativeId(), + 
DType.BOOL8.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false}, // not required + new boolean[]{true, false, true}, // a and c have defaults, b doesn't + new long[]{42, 0, 0}, // default for a + new double[]{0.0, 0.0, 0.0}, + new boolean[]{false, false, true}, // default for c + new byte[][]{null, null, null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultValueWithPartialMessage() { + // optional int32 a = 1 [default = 42]; + // optional int64 b = 2 [default = 100]; + // Message has only field b set, a should use default + Byte[] row = concat( + box(tag(2, WT_VARINT)), box(encodeVarint(999))); // b = 999 + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedA = ColumnVector.fromBoxedInts(42); // default + ColumnVector expectedB = ColumnVector.fromBoxedLongs(999L); // actual value + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedA, expectedB); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, // not required + new boolean[]{true, true}, // both have defaults + new long[]{42, 100}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultStringValue() { + // optional string name = 1 [default = "hello"]; + // Empty message should return the default string + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedStr = ColumnVector.fromStrings("hello"); + ColumnVector expectedStruct = 
ColumnVector.makeStruct(expectedStr); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, // not required + new boolean[]{true}, // has default + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{"hello".getBytes(java.nio.charset.StandardCharsets.UTF_8)}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultStringValueEmpty() { + // optional string name = 1 [default = ""]; + // Empty message with empty default string + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedStr = ColumnVector.fromStrings(""); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedStr); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{new byte[0]}, // empty default string + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultStringValueWithPresent() { + // optional string name = 1 [default = "default"]; + // Message has actual value, should override default + byte[] strBytesRaw = "actual".getBytes(java.nio.charset.StandardCharsets.UTF_8); + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(strBytesRaw.length)), + box(strBytesRaw)); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedStr = ColumnVector.fromStrings("actual"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedStr); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + 
input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{"default".getBytes(java.nio.charset.StandardCharsets.UTF_8)}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultStringWithMixedFields() { + // optional int32 count = 1 [default = 42]; + // optional string name = 2 [default = "test"]; + // Empty message should return both defaults + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(42); + ColumnVector expectedStr = ColumnVector.fromStrings("test"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt, expectedStr); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, + new boolean[]{true, true}, + new long[]{42, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, "test".getBytes(java.nio.charset.StandardCharsets.UTF_8)}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultStringMultipleRows() { + // optional string name = 1 [default = "default"]; + // Multiple rows: empty, has value, empty + Byte[] row1 = new Byte[0]; // will use default + byte[] strBytesRaw = "row2val".getBytes(java.nio.charset.StandardCharsets.UTF_8); + Byte[] row2 = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(strBytesRaw.length)), + box(strBytesRaw)); + Byte[] row3 = new Byte[0]; // will use default + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row1, row2, 
row3}).build(); + ColumnVector expectedStr = ColumnVector.fromStrings("default", "row2val", "default"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedStr); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{"default".getBytes(java.nio.charset.StandardCharsets.UTF_8)}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + // ============================================================================ // Tests for Features Not Yet Implemented (Disabled) // ============================================================================ @@ -1170,4 +1850,176 @@ void testMultiFieldPerformance() { } } } + + // ============================================================================ + // Enum Validation Tests + // ============================================================================ + + /** + * Helper method that wraps decodeToStruct with enum validation support. 
+ */ + private static ColumnVector decodeAllFieldsWithEnums(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] encodings, + int[][] enumValidValues, + boolean failOnErrors) { + int numFields = fieldNumbers.length; + int[] decodedFieldIndices = new int[numFields]; + boolean[] isRequired = new boolean[numFields]; + boolean[] hasDefaultValue = new boolean[numFields]; + long[] defaultInts = new long[numFields]; + double[] defaultFloats = new double[numFields]; + boolean[] defaultBools = new boolean[numFields]; + byte[][] defaultStrings = new byte[numFields][]; + for (int i = 0; i < numFields; i++) { + decodedFieldIndices[i] = i; + } + return Protobuf.decodeToStruct(binaryInput, numFields, decodedFieldIndices, + fieldNumbers, typeIds, encodings, isRequired, + hasDefaultValue, defaultInts, defaultFloats, + defaultBools, defaultStrings, enumValidValues, failOnErrors); + } + + @Test + void testEnumValidValue() { + // enum Color { RED=0; GREEN=1; BLUE=2; } + // message Msg { Color color = 1; } + // Test with valid enum value (GREEN = 1) + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(1))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedColor = ColumnVector.fromBoxedInts(1); // GREEN + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedColor); + ColumnVector actualStruct = decodeAllFieldsWithEnums( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new int[][]{{0, 1, 2}}, // valid enum values: RED=0, GREEN=1, BLUE=2 + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testEnumUnknownValueReturnsNullRow() { + // enum Color { RED=0; GREEN=1; BLUE=2; } + // message Msg { Color color = 1; } + // Test with unknown enum value (999 is not defined) + // The entire struct row should be null (matching Spark CPU PERMISSIVE mode) + Byte[] row = 
concat(box(tag(1, WT_VARINT)), box(encodeVarint(999))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actualStruct = decodeAllFieldsWithEnums( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new int[][]{{0, 1, 2}}, // valid enum values: RED=0, GREEN=1, BLUE=2 + false); + HostColumnVector hostStruct = actualStruct.copyToHost()) { + // The struct itself should be null (not just the field) + assert actualStruct.getNullCount() == 1 : "Struct row should be null for unknown enum"; + assert hostStruct.isNull(0) : "Row 0 should be null"; + } + } + + @Test + void testEnumMixedValidAndUnknown() { + // Test multiple rows with mix of valid and unknown enum values + // Rows with unknown enum values should have null struct (not just null field) + Byte[] row0 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(0))); // RED (valid) -> struct valid + Byte[] row1 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(999))); // unknown -> struct null + Byte[] row2 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(2))); // BLUE (valid) -> struct valid + Byte[] row3 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(-1))); // negative (unknown) -> struct null + + try (Table input = new Table.TestBuilder().column(row0, row1, row2, row3).build(); + ColumnVector actualStruct = decodeAllFieldsWithEnums( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new int[][]{{0, 1, 2}}, // valid enum values + false); + HostColumnVector hostStruct = actualStruct.copyToHost()) { + // Check struct-level nulls + assert actualStruct.getNullCount() == 2 : "Should have 2 null rows (rows 1 and 3)"; + assert !hostStruct.isNull(0) : "Row 0 should be valid"; + assert hostStruct.isNull(1) : "Row 1 should be null (unknown enum 999)"; + assert !hostStruct.isNull(2) : "Row 2 should be valid"; + assert 
hostStruct.isNull(3) : "Row 3 should be null (unknown enum -1)"; + } + } + + @Test + void testEnumWithOtherFields_NullsEntireRow() { + // message Msg { Color color = 1; int32 count = 2; } + // Test that unknown enum value nulls the ENTIRE struct row (not just the enum field) + // This matches Spark CPU PERMISSIVE mode behavior + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(999)), // unknown enum value + box(tag(2, WT_VARINT)), box(encodeVarint(42))); // count = 42 + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actualStruct = decodeAllFieldsWithEnums( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new int[][]{{0, 1, 2}, null}, // first field is enum, second is regular int (no validation) + false); + HostColumnVector hostStruct = actualStruct.copyToHost()) { + // The entire struct row should be null + assert actualStruct.getNullCount() == 1 : "Struct row should be null"; + assert hostStruct.isNull(0) : "Row 0 should be null due to unknown enum"; + } + } + + @Test + void testEnumMissingFieldDoesNotNullRow() { + // Missing enum field should return null for the field, but NOT null the entire row + // Only unknown enum values (present but invalid) trigger row-level null + Byte[] row = new Byte[0]; // empty message + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedColor = ColumnVector.fromBoxedInts((Integer) null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedColor); + ColumnVector actualStruct = decodeAllFieldsWithEnums( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new int[][]{{0, 1, 2}}, // valid enum values + false)) { + // Struct row should be valid (not null), only the field is null + assert 
actualStruct.getNullCount() == 0 : "Struct row should NOT be null for missing field"; + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testEnumValidWithOtherFields() { + // message Msg { Color color = 1; int32 count = 2; } + // Test that valid enum value works correctly with other fields + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), // GREEN (valid) + box(tag(2, WT_VARINT)), box(encodeVarint(42))); // count = 42 + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedColor = ColumnVector.fromBoxedInts(1); + ColumnVector expectedCount = ColumnVector.fromBoxedInts(42); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedColor, expectedCount); + ColumnVector actualStruct = decodeAllFieldsWithEnums( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new int[][]{{0, 1, 2}, null}, // first field is enum, second is regular int + false)) { + // Struct row should be valid with correct values + assert actualStruct.getNullCount() == 0 : "Struct row should be valid"; + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } } From 3b37d28fe610f325368562513978701f84ecde44 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 27 Jan 2026 16:15:19 +0800 Subject: [PATCH 014/107] style Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 35 ++- src/main/cpp/src/protobuf.cu | 525 ++++++++++++++++++------------- src/main/cpp/src/protobuf.hpp | 9 +- 3 files changed, 342 insertions(+), 227 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index 37f4b16775..4dfcda5895 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -72,10 +72,8 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, 
int num_decoded_fields = n_decoded_indices.size(); // Validate array sizes - if (n_field_numbers.size() != num_decoded_fields || - n_encodings.size() != num_decoded_fields || - n_is_required.size() != num_decoded_fields || - n_has_default.size() != num_decoded_fields || + if (n_field_numbers.size() != num_decoded_fields || n_encodings.size() != num_decoded_fields || + n_is_required.size() != num_decoded_fields || n_has_default.size() != num_decoded_fields || n_default_ints.size() != num_decoded_fields || n_default_floats.size() != num_decoded_fields || n_default_bools.size() != num_decoded_fields) { @@ -94,7 +92,7 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, std::vector decoded_indices(n_decoded_indices.begin(), n_decoded_indices.end()); std::vector field_nums(n_field_numbers.begin(), n_field_numbers.end()); std::vector encs(n_encodings.begin(), n_encodings.end()); - + // Convert jboolean arrays to std::vector std::vector required_flags; std::vector has_default_flags; @@ -120,11 +118,10 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, if (byte_arr == nullptr) { default_string_values.emplace_back(); // empty vector for null } else { - jsize len = env->GetArrayLength(byte_arr); + jsize len = env->GetArrayLength(byte_arr); jbyte* bytes = env->GetByteArrayElements(byte_arr, nullptr); - default_string_values.emplace_back( - reinterpret_cast(bytes), - reinterpret_cast(bytes) + len); + default_string_values.emplace_back(reinterpret_cast(bytes), + reinterpret_cast(bytes) + len); env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT); } } @@ -138,7 +135,7 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, if (int_arr == nullptr) { enum_values.emplace_back(); // empty vector for null (not an enum field) } else { - jsize len = env->GetArrayLength(int_arr); + jsize len = env->GetArrayLength(int_arr); jint* ints = env->GetIntArrayElements(int_arr, nullptr); enum_values.emplace_back(ints, ints + len); 
env->ReleaseIntArrayElements(int_arr, ints, JNI_ABORT); @@ -153,10 +150,20 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, all_types.emplace_back(cudf::jni::make_data_type(n_all_type_ids[i], 0)); } - auto result = spark_rapids_jni::decode_protobuf_to_struct( - *input, total_num_fields, decoded_indices, field_nums, all_types, encs, - required_flags, has_default_flags, default_int_values, default_float_values, - default_bool_values, default_string_values, enum_values, fail_on_errors); + auto result = spark_rapids_jni::decode_protobuf_to_struct(*input, + total_num_fields, + decoded_indices, + field_nums, + all_types, + encs, + required_flags, + has_default_flags, + default_int_values, + default_float_values, + default_bool_values, + default_string_values, + enum_values, + fail_on_errors); return cudf::jni::release_as_jlong(result); } JNI_CATCH(env, 0); diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index a886dde4cc..93518a1d0d 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -60,7 +60,7 @@ struct field_location { * Field descriptor passed to the scanning kernel. 
*/ struct field_descriptor { - int field_number; // Protobuf field number + int field_number; // Protobuf field number int expected_wire_type; // Expected wire type for this field }; @@ -73,8 +73,8 @@ __device__ inline bool read_varint(uint8_t const* cur, uint64_t& out, int& bytes) { - out = 0; - bytes = 0; + out = 0; + bytes = 0; int shift = 0; while (cur < end && bytes < 10) { uint8_t b = *cur++; @@ -171,7 +171,7 @@ __global__ void scan_all_fields_kernel( cudf::column_device_view const d_in, field_descriptor const* field_descs, // [num_fields] int num_fields, - field_location* locations, // [num_rows * num_fields] row-major + field_location* locations, // [num_rows * num_fields] row-major int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); @@ -246,8 +246,7 @@ __global__ void scan_all_fields_kernel( return; } // Record offset pointing to the actual data (after length prefix) - locations[row * num_fields + f] = {data_offset + len_bytes, - static_cast(len)}; + locations[row * num_fields + f] = {data_offset + len_bytes, static_cast(len)}; } else { // For fixed-size and varint fields, record offset and compute length int field_size = get_wire_type_size(wt, cur, stop); @@ -282,7 +281,7 @@ __global__ void scan_all_fields_kernel( template __global__ void extract_varint_from_locations_kernel( uint8_t const* message_data, - cudf::size_type const* offsets, // List offsets for each row + cudf::size_type const* offsets, // List offsets for each row cudf::size_type base_offset, field_location const* locations, // [num_rows * num_fields] int field_idx, @@ -291,7 +290,7 @@ __global__ void extract_varint_from_locations_kernel( bool* valid, int num_rows, int* error_flag, - bool has_default = false, + bool has_default = false, int64_t default_value = 0) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); @@ -332,19 +331,18 @@ __global__ void extract_varint_from_locations_kernel( * Supports default values for missing fields. 
*/ template -__global__ void extract_fixed_from_locations_kernel( - uint8_t const* message_data, - cudf::size_type const* offsets, - cudf::size_type base_offset, - field_location const* locations, - int field_idx, - int num_fields, - OutT* out, - bool* valid, - int num_rows, - int* error_flag, - bool has_default = false, - OutT default_value = OutT{}) +__global__ void extract_fixed_from_locations_kernel(uint8_t const* message_data, + cudf::size_type const* offsets, + cudf::size_type base_offset, + field_location const* locations, + int field_idx, + int num_fields, + OutT* out, + bool* valid, + int num_rows, + int* error_flag, + bool has_default = false, + OutT default_value = OutT{}) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; @@ -402,14 +400,14 @@ __global__ void copy_varlen_data_kernel( int32_t const* output_offsets, // Pre-computed output offsets (prefix sum) char* output_data, int num_rows, - bool has_default = false, + bool has_default = false, uint8_t const* default_data = nullptr, - int32_t default_length = 0) + int32_t default_length = 0) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; - auto loc = locations[row * num_fields + field_idx]; + auto loc = locations[row * num_fields + field_idx]; char* dst = output_data + output_offsets[row]; if (loc.offset < 0) { @@ -424,8 +422,8 @@ __global__ void copy_varlen_data_kernel( if (loc.length == 0) return; - auto row_start = input_offsets[row] - base_offset; - uint8_t const* src = message_data + row_start + loc.offset; + auto row_start = input_offsets[row] - base_offset; + uint8_t const* src = message_data + row_start + loc.offset; // Copy data for (int i = 0; i < loc.length; i++) { @@ -437,14 +435,13 @@ __global__ void copy_varlen_data_kernel( * Kernel to extract lengths from locations for prefix sum. * Supports default values for missing fields. 
*/ -__global__ void extract_lengths_kernel( - field_location const* locations, - int field_idx, - int num_fields, - int32_t* lengths, - int num_rows, - bool has_default = false, - int32_t default_length = 0) +__global__ void extract_lengths_kernel(field_location const* locations, + int field_idx, + int num_fields, + int32_t* lengths, + int num_rows, + bool has_default = false, + int32_t default_length = 0) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; @@ -487,7 +484,7 @@ int get_expected_wire_type(cudf::type_id type_id, int encoding) case cudf::type_id::UINT64: if (encoding == spark_rapids_jni::ENC_FIXED) { return (type_id == cudf::type_id::INT32 || type_id == cudf::type_id::UINT32) ? WT_32BIT - : WT_64BIT; + : WT_64BIT; } return WT_VARINT; case cudf::type_id::FLOAT32: return WT_32BIT; @@ -501,11 +498,10 @@ int get_expected_wire_type(cudf::type_id type_id, int encoding) /** * Create an all-null column of the specified type. */ -std::unique_ptr make_null_column( - cudf::data_type dtype, - cudf::size_type num_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) +std::unique_ptr make_null_column(cudf::data_type dtype, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) { if (num_rows == 0) { return cudf::make_empty_column(dtype); } @@ -540,20 +536,15 @@ std::unique_ptr make_null_column( // Offsets: all zeros rmm::device_uvector offsets(num_rows + 1, stream, mr); thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - offsets.release(), - rmm::device_buffer{}, - 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); // Empty child auto child_col = std::make_unique( - cudf::data_type{cudf::type_id::INT8}, - 0, - rmm::device_buffer{}, - 
rmm::device_buffer{}, - 0); + cudf::data_type{cudf::type_id::INT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); // All null mask auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); @@ -604,33 +595,33 @@ __global__ void check_required_fields_kernel( * If a value is not in the valid set: * 1. Mark the field as invalid (valid[row] = false) * 2. Mark the row as having an invalid enum (row_has_invalid_enum[row] = true) - * + * * This matches Spark CPU PERMISSIVE mode behavior: when an unknown enum value is * encountered, the entire struct row is set to null (not just the enum field). - * + * * The valid_values array must be sorted for binary search. */ __global__ void validate_enum_values_kernel( - int32_t const* values, // [num_rows] extracted enum values - bool* valid, // [num_rows] field validity flags (will be modified) - bool* row_has_invalid_enum, // [num_rows] row-level invalid enum flag (will be set to true) - int32_t const* valid_enum_values, // sorted array of valid enum values - int num_valid_values, // size of valid_enum_values + int32_t const* values, // [num_rows] extracted enum values + bool* valid, // [num_rows] field validity flags (will be modified) + bool* row_has_invalid_enum, // [num_rows] row-level invalid enum flag (will be set to true) + int32_t const* valid_enum_values, // sorted array of valid enum values + int num_valid_values, // size of valid_enum_values int num_rows) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; - + // Skip if already invalid (field was missing) - missing field is not an enum error if (!valid[row]) return; - + int32_t val = values[row]; - + // Binary search for the value in valid_enum_values - int left = 0; - int right = num_valid_values - 1; + int left = 0; + int right = num_valid_values - 1; bool found = false; - + while (left <= right) { int mid = left + (right - left) / 2; if (valid_enum_values[mid] == val) { @@ -642,7 +633,7 
@@ __global__ void validate_enum_values_kernel( right = mid - 1; } } - + // If not found, mark as invalid if (!found) { valid[row] = false; @@ -694,9 +685,9 @@ std::unique_ptr decode_protobuf_to_struct( CUDF_EXPECTS(default_strings.size() == field_numbers.size(), "default_strings and field_numbers must have the same length"); - auto const stream = cudf::get_default_stream(); - auto mr = cudf::get_current_device_resource_ref(); - auto rows = binary_input.size(); + auto const stream = cudf::get_default_stream(); + auto mr = cudf::get_current_device_resource_ref(); + auto rows = binary_input.size(); auto num_decoded_fields = static_cast(field_numbers.size()); // Handle zero-row case @@ -726,8 +717,8 @@ std::unique_ptr decode_protobuf_to_struct( // Prepare field descriptors for the scanning kernel std::vector h_field_descs(num_decoded_fields); for (int i = 0; i < num_decoded_fields; i++) { - int schema_idx = decoded_field_indices[i]; - h_field_descs[i].field_number = field_numbers[i]; + int schema_idx = decoded_field_indices[i]; + h_field_descs[i].field_number = field_numbers[i]; h_field_descs[i].expected_wire_type = get_expected_wire_type(all_types[schema_idx].id(), encodings[i]); } @@ -748,15 +739,16 @@ std::unique_ptr decode_protobuf_to_struct( CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); // Check if any field has enum validation - bool has_enum_fields = std::any_of(enum_valid_values.begin(), enum_valid_values.end(), - [](auto const& v) { return !v.empty(); }); - + bool has_enum_fields = std::any_of( + enum_valid_values.begin(), enum_valid_values.end(), [](auto const& v) { return !v.empty(); }); + // Track rows with invalid enum values (used to null entire struct row) // This matches Spark CPU PERMISSIVE mode behavior rmm::device_uvector d_row_has_invalid_enum(has_enum_fields ? 
rows : 0, stream, mr); if (has_enum_fields) { // Initialize all to false (no invalid enums yet) - CUDF_CUDA_TRY(cudaMemsetAsync(d_row_has_invalid_enum.data(), 0, rows * sizeof(bool), stream.value())); + CUDF_CUDA_TRY( + cudaMemsetAsync(d_row_has_invalid_enum.data(), 0, rows * sizeof(bool), stream.value())); } auto const threads = 256; @@ -772,8 +764,8 @@ std::unique_ptr decode_protobuf_to_struct( // Check required fields (after scan pass) // ========================================================================= // Only check if any field is required to avoid unnecessary kernel launch - bool has_required_fields = std::any_of(is_required.begin(), is_required.end(), - [](bool b) { return b; }); + bool has_required_fields = + std::any_of(is_required.begin(), is_required.end(), [](bool b) { return b; }); if (has_required_fields) { // Copy is_required flags to device // Note: std::vector is special (bitfield), so we convert to uint8_t @@ -787,22 +779,18 @@ std::unique_ptr decode_protobuf_to_struct( num_decoded_fields * sizeof(uint8_t), cudaMemcpyHostToDevice, stream.value())); - + check_required_fields_kernel<<>>( d_locations.data(), d_is_required.data(), num_decoded_fields, rows, d_error.data()); } // Get message data pointer and offsets for pass 2 - auto const* message_data = - reinterpret_cast(in_list.child().data()); + auto const* message_data = reinterpret_cast(in_list.child().data()); auto const* list_offsets = in_list.offsets().data(); // Get the base offset by copying from device to host cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&base_offset, - list_offsets, - sizeof(cudf::size_type), - cudaMemcpyDeviceToHost, - stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync( + &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); // ========================================================================= @@ -812,8 +800,7 @@ std::unique_ptr decode_protobuf_to_struct( int decoded_idx = 0; 
for (int schema_idx = 0; schema_idx < total_num_fields; schema_idx++) { - if (decoded_idx < num_decoded_fields && - decoded_field_indices[decoded_idx] == schema_idx) { + if (decoded_idx < num_decoded_fields && decoded_field_indices[decoded_idx] == schema_idx) { // This field needs to be decoded auto const dt = all_types[schema_idx]; auto const enc = encodings[decoded_idx]; @@ -822,21 +809,21 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::BOOL8: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; + bool has_def = has_default_value[decoded_idx]; int64_t def_val = has_def ? (default_bools[decoded_idx] ? 1 : 0) : 0; - extract_varint_from_locations_kernel<<>>( - message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_val); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_val); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = std::make_unique(dt, rows, out.release(), std::move(mask), null_count); @@ -846,26 +833,53 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::INT32: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_int = has_def ? default_ints[decoded_idx] : 0; + bool has_def = has_default_value[decoded_idx]; + int64_t def_int = has_def ? 
default_ints[decoded_idx] : 0; int32_t def_fixed = static_cast(def_int); if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), - has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), - has_def, def_fixed); + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_fixed); } else { - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), - has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); } - + // Validate enum values if this is an enum field // enum_valid_values[decoded_idx] is non-empty for enum fields auto const& valid_enums = enum_valid_values[decoded_idx]; @@ -877,13 +891,17 @@ std::unique_ptr decode_protobuf_to_struct( valid_enums.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - + // Validate enum values - unknown values will null the entire row validate_enum_values_kernel<<>>( - out.data(), valid.data(), d_row_has_invalid_enum.data(), - 
d_valid_enums.data(), static_cast(valid_enums.size()), rows); + out.data(), + valid.data(), + d_row_has_invalid_enum.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + rows); } - + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = std::make_unique(dt, rows, out.release(), std::move(mask), null_count); @@ -893,19 +911,37 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::UINT32: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_int = has_def ? default_ints[decoded_idx] : 0; + bool has_def = has_default_value[decoded_idx]; + int64_t def_int = has_def ? default_ints[decoded_idx] : 0; uint32_t def_fixed = static_cast(def_int); if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), - has_def, def_fixed); + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_fixed); } else { - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), - has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = @@ -916,23 +952,50 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::INT64: { rmm::device_uvector out(rows, stream, mr); 
rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; + bool has_def = has_default_value[decoded_idx]; int64_t def_int = has_def ? default_ints[decoded_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), - has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), - has_def, def_int); + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); } else { - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), - has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = @@ -943,19 +1006,37 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::UINT64: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_int = has_def ? 
default_ints[decoded_idx] : 0; + bool has_def = has_default_value[decoded_idx]; + int64_t def_int = has_def ? default_ints[decoded_idx] : 0; uint64_t def_fixed = static_cast(def_int); if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), - has_def, def_fixed); + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_fixed); } else { - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), - has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = @@ -966,12 +1047,21 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::FLOAT32: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; + bool has_def = has_default_value[decoded_idx]; float def_float = has_def ? 
static_cast(default_floats[decoded_idx]) : 0.0f; - extract_fixed_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), - has_def, def_float); + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_float); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = std::make_unique(dt, rows, out.release(), std::move(mask), null_count); @@ -981,12 +1071,21 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::FLOAT64: { rmm::device_uvector out(rows, stream, mr); rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; + bool has_def = has_default_value[decoded_idx]; double def_double = has_def ? default_floats[decoded_idx] : 0.0; - extract_fixed_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), decoded_idx, - num_decoded_fields, out.data(), valid.data(), rows, d_error.data(), - has_def, def_double); + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_double); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); all_children[schema_idx] = std::make_unique(dt, rows, out.release(), std::move(mask), null_count); @@ -995,9 +1094,9 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::STRING: { // Check for default value - bool has_def = has_default_value[decoded_idx]; + bool has_def = has_default_value[decoded_idx]; auto const& def_str = default_strings[decoded_idx]; - int32_t def_len = has_def ? static_cast(def_str.size()) : 0; + int32_t def_len = has_def ? 
static_cast(def_str.size()) : 0; // Copy default string to device if needed rmm::device_uvector d_default_str(def_len, stream, mr); @@ -1011,9 +1110,13 @@ std::unique_ptr decode_protobuf_to_struct( // Extract lengths and compute output offsets via prefix sum rmm::device_uvector lengths(rows, stream, mr); - extract_lengths_kernel<<>>( - d_locations.data(), decoded_idx, num_decoded_fields, lengths.data(), rows, - has_def, def_len); + extract_lengths_kernel<<>>(d_locations.data(), + decoded_idx, + num_decoded_fields, + lengths.data(), + rows, + has_def, + def_len); rmm::device_uvector output_offsets(rows + 1, stream, mr); thrust::exclusive_scan( @@ -1045,19 +1148,18 @@ std::unique_ptr decode_protobuf_to_struct( // Allocate and copy character data rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { - copy_varlen_data_kernel<<>>( - message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - output_offsets.data(), - chars.data(), - rows, - has_def, - d_default_str.data(), - def_len); + copy_varlen_data_kernel<<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + output_offsets.data(), + chars.data(), + rows, + has_def, + d_default_str.data(), + def_len); } // Create validity mask (field found OR has default = valid) @@ -1067,18 +1169,18 @@ std::unique_ptr decode_protobuf_to_struct( thrust::make_counting_iterator(0), thrust::make_counting_iterator(rows), valid.begin(), - [locs = d_locations.data(), decoded_idx, num_decoded_fields, has_def] __device__(auto row) { + [locs = d_locations.data(), decoded_idx, num_decoded_fields, has_def] __device__( + auto row) { return locs[row * num_decoded_fields + decoded_idx].offset >= 0 || has_def; }); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); // Create offsets column - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, - rows + 1, - output_offsets.release(), 
- rmm::device_buffer{}, - 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); // Create strings column using offsets + chars buffer all_children[schema_idx] = cudf::make_strings_column( @@ -1089,9 +1191,9 @@ std::unique_ptr decode_protobuf_to_struct( case cudf::type_id::LIST: { // For protobuf bytes: create LIST directly (optimization #2) // Check for default value - bool has_def = has_default_value[decoded_idx]; + bool has_def = has_default_value[decoded_idx]; auto const& def_bytes = default_strings[decoded_idx]; - int32_t def_len = has_def ? static_cast(def_bytes.size()) : 0; + int32_t def_len = has_def ? static_cast(def_bytes.size()) : 0; // Copy default bytes to device if needed rmm::device_uvector d_default_bytes(def_len, stream, mr); @@ -1105,9 +1207,13 @@ std::unique_ptr decode_protobuf_to_struct( // Extract lengths and compute output offsets via prefix sum rmm::device_uvector lengths(rows, stream, mr); - extract_lengths_kernel<<>>( - d_locations.data(), decoded_idx, num_decoded_fields, lengths.data(), rows, - has_def, def_len); + extract_lengths_kernel<<>>(d_locations.data(), + decoded_idx, + num_decoded_fields, + lengths.data(), + rows, + has_def, + def_len); rmm::device_uvector output_offsets(rows + 1, stream, mr); thrust::exclusive_scan( @@ -1161,26 +1267,25 @@ std::unique_ptr decode_protobuf_to_struct( thrust::make_counting_iterator(0), thrust::make_counting_iterator(rows), valid.begin(), - [locs = d_locations.data(), decoded_idx, num_decoded_fields, has_def] __device__(auto row) { + [locs = d_locations.data(), decoded_idx, num_decoded_fields, has_def] __device__( + auto row) { return locs[row * num_decoded_fields + decoded_idx].offset >= 0 || has_def; }); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); // Create offsets column - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, - rows + 1, - 
output_offsets.release(), - rmm::device_buffer{}, - 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); // Create INT8 child column directly (no intermediate strings column!) - auto child_col = std::make_unique( - cudf::data_type{cudf::type_id::INT8}, - total_bytes, - child_data.release(), - rmm::device_buffer{}, - 0); + auto child_col = std::make_unique(cudf::data_type{cudf::type_id::INT8}, + total_bytes, + child_data.release(), + rmm::device_buffer{}, + 0); all_children[schema_idx] = cudf::make_lists_column(rows, std::move(offsets_col), @@ -1214,8 +1319,8 @@ std::unique_ptr decode_protobuf_to_struct( cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); if (fail_on_errors) { - CUDF_EXPECTS(h_error == 0, - "Malformed protobuf message, unsupported wire type, or missing required field"); + CUDF_EXPECTS(h_error == 0, + "Malformed protobuf message, unsupported wire type, or missing required field"); } // Build the final struct @@ -1223,7 +1328,7 @@ std::unique_ptr decode_protobuf_to_struct( // This matches Spark CPU PERMISSIVE mode: unknown enum values null the entire row cudf::size_type struct_null_count = 0; rmm::device_buffer struct_mask{0, stream, mr}; - + if (has_enum_fields) { // Create struct null mask: row is valid if it has NO invalid enums auto [mask, null_count] = cudf::detail::valid_if( @@ -1234,10 +1339,10 @@ std::unique_ptr decode_protobuf_to_struct( }, stream, mr); - struct_mask = std::move(mask); + struct_mask = std::move(mask); struct_null_count = null_count; } - + return cudf::make_structs_column( rows, std::move(all_children), struct_null_count, std::move(struct_mask), stream, mr); } diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index 817374dc66..0e398af39d 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -63,13 +63,16 @@ 
constexpr int ENC_ZIGZAG = 2; * @param total_num_fields Total number of fields in the output struct (including null columns) * @param decoded_field_indices Indices into the output struct for fields that should be decoded. * Fields not in this list will be null columns in the output. - * @param field_numbers Protobuf field numbers for decoded fields (parallel to decoded_field_indices) + * @param field_numbers Protobuf field numbers for decoded fields (parallel to + * decoded_field_indices) * @param all_types Output cudf data types for ALL fields in the struct (size = total_num_fields) * @param encodings Encoding type for each decoded field (0=default, 1=fixed, 2=zigzag) * (parallel to decoded_field_indices) * @param is_required Whether each decoded field is required (parallel to decoded_field_indices). - * If a required field is missing and fail_on_errors is true, an exception is thrown. - * @param has_default_value Whether each decoded field has a default value (parallel to decoded_field_indices) + * If a required field is missing and fail_on_errors is true, an exception is + * thrown. 
+ * @param has_default_value Whether each decoded field has a default value (parallel to + * decoded_field_indices) * @param default_ints Default values for int/long/enum fields (parallel to decoded_field_indices) * @param default_floats Default values for float/double fields (parallel to decoded_field_indices) * @param default_bools Default values for bool fields (parallel to decoded_field_indices) From 46d21304da9c72626e92810be4c91a96f66bb2fa Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 30 Jan 2026 18:16:07 +0800 Subject: [PATCH 015/107] Support nested types Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 156 + src/main/cpp/src/protobuf.cu | 4571 ++++++++++++++--- src/main/cpp/src/protobuf.hpp | 50 + .../com/nvidia/spark/rapids/jni/Protobuf.java | 117 + .../nvidia/spark/rapids/jni/ProtobufTest.java | 74 +- 5 files changed, 4211 insertions(+), 757 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index 4dfcda5895..f626ff291e 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -169,4 +169,160 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, JNI_CATCH(env, 0); } +JNIEXPORT jlong JNICALL +Java_com_nvidia_spark_rapids_jni_Protobuf_decodeNestedToStruct(JNIEnv* env, + jclass, + jlong binary_input_view, + jintArray field_numbers, + jintArray parent_indices, + jintArray depth_levels, + jintArray wire_types, + jintArray output_type_ids, + jintArray encodings, + jbooleanArray is_repeated, + jbooleanArray is_required, + jbooleanArray has_default_value, + jlongArray default_ints, + jdoubleArray default_floats, + jbooleanArray default_bools, + jobjectArray default_strings, + jobjectArray enum_valid_values, + jboolean fail_on_errors) +{ + JNI_NULL_CHECK(env, binary_input_view, "binary_input_view is null", 0); + JNI_NULL_CHECK(env, field_numbers, "field_numbers is null", 0); + JNI_NULL_CHECK(env, parent_indices, "parent_indices is null", 0); + 
JNI_NULL_CHECK(env, depth_levels, "depth_levels is null", 0); + JNI_NULL_CHECK(env, wire_types, "wire_types is null", 0); + JNI_NULL_CHECK(env, output_type_ids, "output_type_ids is null", 0); + JNI_NULL_CHECK(env, encodings, "encodings is null", 0); + JNI_NULL_CHECK(env, is_repeated, "is_repeated is null", 0); + JNI_NULL_CHECK(env, is_required, "is_required is null", 0); + JNI_NULL_CHECK(env, has_default_value, "has_default_value is null", 0); + JNI_NULL_CHECK(env, default_ints, "default_ints is null", 0); + JNI_NULL_CHECK(env, default_floats, "default_floats is null", 0); + JNI_NULL_CHECK(env, default_bools, "default_bools is null", 0); + JNI_NULL_CHECK(env, default_strings, "default_strings is null", 0); + JNI_NULL_CHECK(env, enum_valid_values, "enum_valid_values is null", 0); + + JNI_TRY + { + cudf::jni::auto_set_device(env); + auto const* input = reinterpret_cast(binary_input_view); + + cudf::jni::native_jintArray n_field_numbers(env, field_numbers); + cudf::jni::native_jintArray n_parent_indices(env, parent_indices); + cudf::jni::native_jintArray n_depth_levels(env, depth_levels); + cudf::jni::native_jintArray n_wire_types(env, wire_types); + cudf::jni::native_jintArray n_output_type_ids(env, output_type_ids); + cudf::jni::native_jintArray n_encodings(env, encodings); + cudf::jni::native_jbooleanArray n_is_repeated(env, is_repeated); + cudf::jni::native_jbooleanArray n_is_required(env, is_required); + cudf::jni::native_jbooleanArray n_has_default(env, has_default_value); + cudf::jni::native_jlongArray n_default_ints(env, default_ints); + cudf::jni::native_jdoubleArray n_default_floats(env, default_floats); + cudf::jni::native_jbooleanArray n_default_bools(env, default_bools); + + int num_fields = n_field_numbers.size(); + + // Validate array sizes + if (n_parent_indices.size() != num_fields || + n_depth_levels.size() != num_fields || + n_wire_types.size() != num_fields || + n_output_type_ids.size() != num_fields || + n_encodings.size() != num_fields || + 
n_is_repeated.size() != num_fields || + n_is_required.size() != num_fields || + n_has_default.size() != num_fields || + n_default_ints.size() != num_fields || + n_default_floats.size() != num_fields || + n_default_bools.size() != num_fields) { + JNI_THROW_NEW(env, + cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, + "All field arrays must have the same length", + 0); + } + + // Build schema descriptors + std::vector schema; + schema.reserve(num_fields); + for (int i = 0; i < num_fields; ++i) { + schema.push_back({ + n_field_numbers[i], + n_parent_indices[i], + n_depth_levels[i], + n_wire_types[i], + static_cast(n_output_type_ids[i]), + n_encodings[i], + n_is_repeated[i] != 0, + n_is_required[i] != 0, + n_has_default[i] != 0 + }); + } + + // Build output types + std::vector schema_output_types; + schema_output_types.reserve(num_fields); + for (int i = 0; i < num_fields; ++i) { + schema_output_types.emplace_back(static_cast(n_output_type_ids[i])); + } + + // Convert boolean arrays + std::vector default_bool_values; + default_bool_values.reserve(num_fields); + for (int i = 0; i < num_fields; ++i) { + default_bool_values.push_back(n_default_bools[i] != 0); + } + + // Convert default values + std::vector default_int_values(n_default_ints.begin(), n_default_ints.end()); + std::vector default_float_values(n_default_floats.begin(), n_default_floats.end()); + + // Convert default string values + std::vector> default_string_values; + default_string_values.reserve(num_fields); + for (int i = 0; i < num_fields; ++i) { + jbyteArray byte_arr = static_cast(env->GetObjectArrayElement(default_strings, i)); + if (byte_arr == nullptr) { + default_string_values.emplace_back(); + } else { + jsize len = env->GetArrayLength(byte_arr); + jbyte* bytes = env->GetByteArrayElements(byte_arr, nullptr); + default_string_values.emplace_back(reinterpret_cast(bytes), + reinterpret_cast(bytes) + len); + env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT); + } + } + + // Convert enum valid values + 
std::vector> enum_values; + enum_values.reserve(num_fields); + for (int i = 0; i < num_fields; ++i) { + jintArray int_arr = static_cast(env->GetObjectArrayElement(enum_valid_values, i)); + if (int_arr == nullptr) { + enum_values.emplace_back(); + } else { + jsize len = env->GetArrayLength(int_arr); + jint* ints = env->GetIntArrayElements(int_arr, nullptr); + enum_values.emplace_back(ints, ints + len); + env->ReleaseIntArrayElements(int_arr, ints, JNI_ABORT); + } + } + + auto result = spark_rapids_jni::decode_nested_protobuf_to_struct( + *input, + schema, + schema_output_types, + default_int_values, + default_float_values, + default_bool_values, + default_string_values, + enum_values, + fail_on_errors); + + return cudf::jni::release_as_jlong(result); + } + JNI_CATCH(env, 0); +} + } // extern "C" diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 93518a1d0d..e887e9b017 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -35,6 +35,11 @@ #include #include +#include +#include +#include +#include + namespace { // Wire type constants @@ -64,6 +69,38 @@ struct field_descriptor { int expected_wire_type; // Expected wire type for this field }; +/** + * Information about repeated field occurrences in a row. + */ +struct repeated_field_info { + int32_t count; // Number of occurrences in this row + int32_t total_length; // Total bytes for all occurrences (for varlen fields) +}; + +/** + * Location of a single occurrence of a repeated field. + */ +struct repeated_occurrence { + int32_t row_idx; // Which row this occurrence belongs to + int32_t offset; // Offset within the message + int32_t length; // Length of the field data +}; + +/** + * Device-side descriptor for nested schema fields. 
+ */ +struct device_nested_field_descriptor { + int field_number; + int parent_idx; + int depth; + int wire_type; + int output_type_id; + int encoding; + bool is_repeated; + bool is_required; + bool has_default_value; +}; + // ============================================================================ // Device helper functions // ============================================================================ @@ -135,6 +172,37 @@ __device__ inline bool skip_field(uint8_t const* cur, return true; } +/** + * Get the data offset and length for a field at current position. + * Returns true on success, false on error. + */ +__device__ inline bool get_field_data_location(uint8_t const* cur, + uint8_t const* end, + int wt, + int32_t& data_offset, + int32_t& data_length) +{ + if (wt == WT_LEN) { + // For length-delimited, read the length prefix + uint64_t len; + int len_bytes; + if (!read_varint(cur, end, len, len_bytes)) return false; + if (len > static_cast(end - cur - len_bytes) || + len > static_cast(INT_MAX)) { + return false; + } + data_offset = len_bytes; // offset past the length prefix + data_length = static_cast(len); + } else { + // For fixed-size and varint fields + int field_size = get_wire_type_size(wt, cur, end); + if (field_size < 0) return false; + data_offset = 0; + data_length = field_size; + } + return true; +} + template __device__ inline T load_le(uint8_t const* p); @@ -270,6 +338,310 @@ __global__ void scan_all_fields_kernel( } } +// ============================================================================ +// Pass 1b: Count repeated fields kernel +// ============================================================================ + +/** + * Count occurrences of repeated fields in each row. + * Also records locations of nested message fields for hierarchical processing. 
+ */ +__global__ void count_repeated_fields_kernel( + cudf::column_device_view const d_in, + device_nested_field_descriptor const* schema, + int num_fields, + int depth_level, // Which depth level we're processing + repeated_field_info* repeated_info, // [num_rows * num_repeated_fields_at_this_depth] + int num_repeated_fields, // Number of repeated fields at this depth + int const* repeated_field_indices, // Indices into schema for repeated fields at this depth + field_location* nested_locations, // Locations of nested messages for next depth [num_rows * num_nested] + int num_nested_fields, // Number of nested message fields at this depth + int const* nested_field_indices, // Indices into schema for nested message fields + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + cudf::detail::lists_column_device_view in{d_in}; + if (row >= in.size()) return; + + // Initialize repeated counts to 0 + for (int f = 0; f < num_repeated_fields; f++) { + repeated_info[row * num_repeated_fields + f] = {0, 0}; + } + + // Initialize nested locations to not found + for (int f = 0; f < num_nested_fields; f++) { + nested_locations[row * num_nested_fields + f] = {-1, 0}; + } + + if (in.nullable() && in.is_null(row)) { + return; + } + + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); + auto const* bytes = reinterpret_cast(child.data()); + auto start = in.offset_at(row) - base; + auto end = in.offset_at(row + 1) - base; + + if (start < 0 || end < start || end > child.size()) { + atomicExch(error_flag, 1); + return; + } + + uint8_t const* cur = bytes + start; + uint8_t const* stop = bytes + end; + + while (cur < stop) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, stop, key, key_bytes)) { + atomicExch(error_flag, 1); + return; + } + cur += key_bytes; + + int fn = static_cast(key >> 3); + int wt = static_cast(key & 0x7); + + if (fn == 0) { + atomicExch(error_flag, 1); + return; + } + + // Check repeated fields 
at this depth + for (int i = 0; i < num_repeated_fields; i++) { + int schema_idx = repeated_field_indices[i]; + if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { + int expected_wt = schema[schema_idx].wire_type; + + // Handle both packed and unpacked encoding for repeated fields + // Packed encoding uses wire type LEN (2) even for scalar types + bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); + + if (!is_packed && wt != expected_wt) { + atomicExch(error_flag, 1); + return; + } + + if (is_packed) { + // Packed encoding: read length, then count elements inside + uint64_t packed_len; + int len_bytes; + if (!read_varint(cur, stop, packed_len, len_bytes)) { + atomicExch(error_flag, 1); + return; + } + + // Count elements based on type + uint8_t const* packed_start = cur + len_bytes; + uint8_t const* packed_end = packed_start + packed_len; + if (packed_end > stop) { + atomicExch(error_flag, 1); + return; + } + + int count = 0; + if (expected_wt == WT_VARINT) { + // Count varints in the packed data + uint8_t const* p = packed_start; + while (p < packed_end) { + uint64_t dummy; + int vbytes; + if (!read_varint(p, packed_end, dummy, vbytes)) { + atomicExch(error_flag, 1); + return; + } + p += vbytes; + count++; + } + } else if (expected_wt == WT_32BIT) { + count = static_cast(packed_len) / 4; + } else if (expected_wt == WT_64BIT) { + count = static_cast(packed_len) / 8; + } + + repeated_info[row * num_repeated_fields + i].count += count; + repeated_info[row * num_repeated_fields + i].total_length += static_cast(packed_len); + } else { + // Non-packed encoding: single element + int32_t data_offset, data_length; + if (!get_field_data_location(cur, stop, wt, data_offset, data_length)) { + atomicExch(error_flag, 1); + return; + } + + repeated_info[row * num_repeated_fields + i].count++; + repeated_info[row * num_repeated_fields + i].total_length += data_length; + } + } + } + + // Check nested message fields at this depth (last one 
wins for non-repeated) + for (int i = 0; i < num_nested_fields; i++) { + int schema_idx = nested_field_indices[i]; + if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { + if (wt != WT_LEN) { + atomicExch(error_flag, 1); + return; + } + + uint64_t len; + int len_bytes; + if (!read_varint(cur, stop, len, len_bytes)) { + atomicExch(error_flag, 1); + return; + } + + int32_t msg_offset = static_cast(cur - bytes - start) + len_bytes; + nested_locations[row * num_nested_fields + i] = {msg_offset, static_cast(len)}; + } + } + + // Skip to next field + uint8_t const* next; + if (!skip_field(cur, stop, wt, next)) { + atomicExch(error_flag, 1); + return; + } + cur = next; + } +} + +/** + * Scan and record all occurrences of repeated fields. + * Called after count_repeated_fields_kernel to fill in actual locations. + */ +__global__ void scan_repeated_field_occurrences_kernel( + cudf::column_device_view const d_in, + device_nested_field_descriptor const* schema, + int schema_idx, // Which field in schema we're scanning + int depth_level, + int32_t const* output_offsets, // Pre-computed offsets from prefix sum [num_rows + 1] + repeated_occurrence* occurrences, // Output: all occurrences [total_count] + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + cudf::detail::lists_column_device_view in{d_in}; + if (row >= in.size()) return; + + if (in.nullable() && in.is_null(row)) { + return; + } + + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); + auto const* bytes = reinterpret_cast(child.data()); + auto start = in.offset_at(row) - base; + auto end = in.offset_at(row + 1) - base; + + if (start < 0 || end < start || end > child.size()) { + atomicExch(error_flag, 1); + return; + } + + uint8_t const* cur = bytes + start; + uint8_t const* stop = bytes + end; + + int target_fn = schema[schema_idx].field_number; + int target_wt = schema[schema_idx].wire_type; + int write_idx = 
output_offsets[row]; + + while (cur < stop) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, stop, key, key_bytes)) { + atomicExch(error_flag, 1); + return; + } + cur += key_bytes; + + int fn = static_cast(key >> 3); + int wt = static_cast(key & 0x7); + + if (fn == 0) { + atomicExch(error_flag, 1); + return; + } + + if (fn == target_fn) { + // Check for packed encoding: wire type LEN but expected non-LEN + bool is_packed = (wt == WT_LEN && target_wt != WT_LEN); + + if (is_packed) { + // Packed encoding: multiple elements in a length-delimited blob + uint64_t packed_len; + int len_bytes; + if (!read_varint(cur, stop, packed_len, len_bytes)) { + atomicExch(error_flag, 1); + return; + } + + uint8_t const* packed_start = cur + len_bytes; + uint8_t const* packed_end = packed_start + packed_len; + if (packed_end > stop) { + atomicExch(error_flag, 1); + return; + } + + // Record each element in the packed blob + if (target_wt == WT_VARINT) { + // Varints: parse each one + uint8_t const* p = packed_start; + while (p < packed_end) { + int32_t elem_offset = static_cast(p - bytes - start); + uint64_t dummy; + int vbytes; + if (!read_varint(p, packed_end, dummy, vbytes)) { + atomicExch(error_flag, 1); + return; + } + occurrences[write_idx] = {static_cast(row), elem_offset, vbytes}; + write_idx++; + p += vbytes; + } + } else if (target_wt == WT_32BIT) { + // Fixed 32-bit: each element is 4 bytes + uint8_t const* p = packed_start; + while (p + 4 <= packed_end) { + int32_t elem_offset = static_cast(p - bytes - start); + occurrences[write_idx] = {static_cast(row), elem_offset, 4}; + write_idx++; + p += 4; + } + } else if (target_wt == WT_64BIT) { + // Fixed 64-bit: each element is 8 bytes + uint8_t const* p = packed_start; + while (p + 8 <= packed_end) { + int32_t elem_offset = static_cast(p - bytes - start); + occurrences[write_idx] = {static_cast(row), elem_offset, 8}; + write_idx++; + p += 8; + } + } + } else if (wt == target_wt) { + // Non-packed encoding: single 
element + int32_t data_offset, data_length; + if (!get_field_data_location(cur, stop, wt, data_offset, data_length)) { + atomicExch(error_flag, 1); + return; + } + + int32_t abs_offset = static_cast(cur - bytes - start) + data_offset; + occurrences[write_idx] = {static_cast(row), abs_offset, data_length}; + write_idx++; + } + } + + // Skip to next field + uint8_t const* next; + if (!skip_field(cur, stop, wt, next)) { + atomicExch(error_flag, 1); + return; + } + cur = next; + } +} + // ============================================================================ // Pass 2: Extract data kernels // ============================================================================ @@ -457,292 +829,1414 @@ __global__ void extract_lengths_kernel(field_location const* locations, } // ============================================================================ -// Utility functions +// Repeated field extraction kernels // ============================================================================ -inline std::pair make_null_mask_from_valid( - rmm::device_uvector const& valid, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) +/** + * Extract repeated varint values using pre-recorded occurrences. 
+ */ +template +__global__ void extract_repeated_varint_kernel( + uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + repeated_occurrence const* occurrences, + int total_occurrences, + OutT* out, + int* error_flag) { - auto begin = thrust::make_counting_iterator(0); - auto end = begin + valid.size(); - auto pred = [ptr = valid.data()] __device__(cudf::size_type i) { return ptr[i]; }; - return cudf::detail::valid_if(begin, end, pred, stream, mr); + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_occurrences) return; + + auto const& occ = occurrences[idx]; + auto row_start = row_offsets[occ.row_idx] - base_offset; + uint8_t const* cur = message_data + row_start + occ.offset; + uint8_t const* cur_end = cur + occ.length; + + uint64_t v; + int n; + if (!read_varint(cur, cur_end, v, n)) { + atomicExch(error_flag, 1); + out[idx] = OutT{}; + return; + } + + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } + out[idx] = static_cast(v); } /** - * Get the expected wire type for a given cudf type and encoding. + * Extract repeated fixed-size values using pre-recorded occurrences. */ -int get_expected_wire_type(cudf::type_id type_id, int encoding) +template +__global__ void extract_repeated_fixed_kernel( + uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + repeated_occurrence const* occurrences, + int total_occurrences, + OutT* out, + int* error_flag) { - switch (type_id) { - case cudf::type_id::BOOL8: - case cudf::type_id::INT32: - case cudf::type_id::UINT32: - case cudf::type_id::INT64: - case cudf::type_id::UINT64: - if (encoding == spark_rapids_jni::ENC_FIXED) { - return (type_id == cudf::type_id::INT32 || type_id == cudf::type_id::UINT32) ? 
WT_32BIT - : WT_64BIT; - } - return WT_VARINT; - case cudf::type_id::FLOAT32: return WT_32BIT; - case cudf::type_id::FLOAT64: return WT_64BIT; - case cudf::type_id::STRING: - case cudf::type_id::LIST: return WT_LEN; - default: CUDF_FAIL("Unsupported type for protobuf decoding"); + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_occurrences) return; + + auto const& occ = occurrences[idx]; + auto row_start = row_offsets[occ.row_idx] - base_offset; + uint8_t const* cur = message_data + row_start + occ.offset; + + OutT value; + if constexpr (WT == WT_32BIT) { + if (occ.length < 4) { + atomicExch(error_flag, 1); + out[idx] = OutT{}; + return; + } + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { + if (occ.length < 8) { + atomicExch(error_flag, 1); + out[idx] = OutT{}; + return; + } + uint64_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); } + + out[idx] = value; } /** - * Create an all-null column of the specified type. + * Copy repeated variable-length data (string/bytes) using pre-recorded occurrences. 
*/ -std::unique_ptr make_null_column(cudf::data_type dtype, - cudf::size_type num_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) +__global__ void copy_repeated_varlen_data_kernel( + uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + repeated_occurrence const* occurrences, + int total_occurrences, + int32_t const* output_offsets, // Pre-computed output offsets for strings + char* output_data) { - if (num_rows == 0) { return cudf::make_empty_column(dtype); } + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_occurrences) return; - switch (dtype.id()) { - case cudf::type_id::BOOL8: - case cudf::type_id::INT8: - case cudf::type_id::UINT8: - case cudf::type_id::INT16: - case cudf::type_id::UINT16: - case cudf::type_id::INT32: - case cudf::type_id::UINT32: - case cudf::type_id::INT64: - case cudf::type_id::UINT64: - case cudf::type_id::FLOAT32: - case cudf::type_id::FLOAT64: { - auto data = rmm::device_buffer(cudf::size_of(dtype) * num_rows, stream, mr); - auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); - return std::make_unique( - dtype, num_rows, std::move(data), std::move(null_mask), num_rows); - } - case cudf::type_id::STRING: { - // Create empty strings column with all nulls - rmm::device_uvector pairs(num_rows, stream, mr); - thrust::fill(rmm::exec_policy(stream), - pairs.begin(), - pairs.end(), - cudf::strings::detail::string_index_pair{nullptr, 0}); - return cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr); - } - case cudf::type_id::LIST: { - // Create LIST with all nulls - // Offsets: all zeros - rmm::device_uvector offsets(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - offsets.release(), - rmm::device_buffer{}, - 0); + auto 
const& occ = occurrences[idx]; + if (occ.length == 0) return; - // Empty child - auto child_col = std::make_unique( - cudf::data_type{cudf::type_id::INT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); - - // All null mask - auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + auto row_start = row_offsets[occ.row_idx] - base_offset; + uint8_t const* src = message_data + row_start + occ.offset; + char* dst = output_data + output_offsets[idx]; - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - num_rows, - std::move(null_mask), - stream, - mr); - } - default: CUDF_FAIL("Unsupported type for null column creation"); + for (int i = 0; i < occ.length; i++) { + dst[i] = static_cast(src[i]); } } -} // namespace +/** + * Extract lengths from repeated occurrences for prefix sum. + */ +__global__ void extract_repeated_lengths_kernel( + repeated_occurrence const* occurrences, + int total_occurrences, + int32_t* lengths) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_occurrences) return; + + lengths[idx] = occurrences[idx].length; +} // ============================================================================ -// Kernel to check required fields after scan pass +// Nested message scanning kernels // ============================================================================ /** - * Check if any required fields are missing (offset < 0) and set error flag. - * This is called after the scan pass to validate required field constraints. + * Scan nested message fields. + * Each row represents a nested message at a specific parent location. + * This kernel finds fields within the nested message bytes. 
*/ -__global__ void check_required_fields_kernel( - field_location const* locations, // [num_rows * num_fields] - uint8_t const* is_required, // [num_fields] (1 = required, 0 = optional) +__global__ void scan_nested_message_fields_kernel( + uint8_t const* message_data, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + int num_parent_rows, + field_descriptor const* field_descs, int num_fields, - int num_rows, + field_location* output_locations, int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; + if (row >= num_parent_rows) return; for (int f = 0; f < num_fields; f++) { - if (is_required[f] != 0 && locations[row * num_fields + f].offset < 0) { - // Required field is missing - set error flag - atomicExch(error_flag, 1); - return; // No need to check other fields for this row - } + output_locations[row * num_fields + f] = {-1, 0}; } -} - -/** - * Validate enum values against a set of valid values. - * If a value is not in the valid set: - * 1. Mark the field as invalid (valid[row] = false) - * 2. Mark the row as having an invalid enum (row_has_invalid_enum[row] = true) - * - * This matches Spark CPU PERMISSIVE mode behavior: when an unknown enum value is - * encountered, the entire struct row is set to null (not just the enum field). - * - * The valid_values array must be sorted for binary search. 
- */ -__global__ void validate_enum_values_kernel( - int32_t const* values, // [num_rows] extracted enum values - bool* valid, // [num_rows] field validity flags (will be modified) - bool* row_has_invalid_enum, // [num_rows] row-level invalid enum flag (will be set to true) - int32_t const* valid_enum_values, // sorted array of valid enum values - int num_valid_values, // size of valid_enum_values - int num_rows) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - // Skip if already invalid (field was missing) - missing field is not an enum error - if (!valid[row]) return; + auto const& parent_loc = parent_locations[row]; + if (parent_loc.offset < 0) { + return; + } - int32_t val = values[row]; + auto parent_row_start = parent_row_offsets[row] - parent_base_offset; + uint8_t const* nested_start = message_data + parent_row_start + parent_loc.offset; + uint8_t const* nested_end = nested_start + parent_loc.length; - // Binary search for the value in valid_enum_values - int left = 0; - int right = num_valid_values - 1; - bool found = false; + uint8_t const* cur = nested_start; - while (left <= right) { - int mid = left + (right - left) / 2; - if (valid_enum_values[mid] == val) { - found = true; - break; - } else if (valid_enum_values[mid] < val) { - left = mid + 1; - } else { - right = mid - 1; + while (cur < nested_end) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, nested_end, key, key_bytes)) { + atomicExch(error_flag, 1); + return; } - } + cur += key_bytes; - // If not found, mark as invalid - if (!found) { - valid[row] = false; - // Also mark the row as having an invalid enum - this will null the entire struct row - row_has_invalid_enum[row] = true; - } -} + int fn = static_cast(key >> 3); + int wt = static_cast(key & 0x7); -namespace spark_rapids_jni { + if (fn == 0) { + atomicExch(error_flag, 1); + return; + } -std::unique_ptr decode_protobuf_to_struct( - cudf::column_view const& binary_input, - int 
total_num_fields, - std::vector const& decoded_field_indices, - std::vector const& field_numbers, - std::vector const& all_types, - std::vector const& encodings, - std::vector const& is_required, - std::vector const& has_default_value, - std::vector const& default_ints, - std::vector const& default_floats, - std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - bool fail_on_errors) -{ - CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, - "binary_input must be a LIST column"); - cudf::lists_column_view const in_list(binary_input); - auto const child_type = in_list.child().type().id(); - CUDF_EXPECTS(child_type == cudf::type_id::INT8 || child_type == cudf::type_id::UINT8, - "binary_input must be a LIST column"); - CUDF_EXPECTS(static_cast(all_types.size()) == total_num_fields, - "all_types size must equal total_num_fields"); - CUDF_EXPECTS(decoded_field_indices.size() == field_numbers.size(), - "decoded_field_indices and field_numbers must have the same length"); - CUDF_EXPECTS(encodings.size() == field_numbers.size(), - "encodings and field_numbers must have the same length"); - CUDF_EXPECTS(is_required.size() == field_numbers.size(), - "is_required and field_numbers must have the same length"); - CUDF_EXPECTS(has_default_value.size() == field_numbers.size(), - "has_default_value and field_numbers must have the same length"); - CUDF_EXPECTS(default_ints.size() == field_numbers.size(), - "default_ints and field_numbers must have the same length"); - CUDF_EXPECTS(default_floats.size() == field_numbers.size(), - "default_floats and field_numbers must have the same length"); - CUDF_EXPECTS(default_bools.size() == field_numbers.size(), - "default_bools and field_numbers must have the same length"); - CUDF_EXPECTS(default_strings.size() == field_numbers.size(), - "default_strings and field_numbers must have the same length"); + for (int f = 0; f < num_fields; f++) { + if (field_descs[f].field_number 
== fn) { + if (wt != field_descs[f].expected_wire_type) { + atomicExch(error_flag, 1); + return; + } - auto const stream = cudf::get_default_stream(); - auto mr = cudf::get_current_device_resource_ref(); - auto rows = binary_input.size(); - auto num_decoded_fields = static_cast(field_numbers.size()); + int data_offset = static_cast(cur - nested_start); - // Handle zero-row case - if (rows == 0) { - std::vector> empty_children; - empty_children.reserve(total_num_fields); - for (auto const& dt : all_types) { - empty_children.push_back(cudf::make_empty_column(dt)); + if (wt == WT_LEN) { + uint64_t len; + int len_bytes; + if (!read_varint(cur, nested_end, len, len_bytes)) { + atomicExch(error_flag, 1); + return; + } + if (len > static_cast(nested_end - cur - len_bytes) || + len > static_cast(INT_MAX)) { + atomicExch(error_flag, 1); + return; + } + output_locations[row * num_fields + f] = {data_offset + len_bytes, static_cast(len)}; + } else { + int field_size = get_wire_type_size(wt, cur, nested_end); + if (field_size < 0) { + atomicExch(error_flag, 1); + return; + } + output_locations[row * num_fields + f] = {data_offset, field_size}; + } + } } - return cudf::make_structs_column( - 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); - } - // Handle case with no fields to decode - if (num_decoded_fields == 0) { - std::vector> null_children; - null_children.reserve(total_num_fields); - for (auto const& dt : all_types) { - null_children.push_back(make_null_column(dt, rows, stream, mr)); + uint8_t const* next; + if (!skip_field(cur, nested_end, wt, next)) { + atomicExch(error_flag, 1); + return; } - return cudf::make_structs_column( - rows, std::move(null_children), 0, rmm::device_buffer{}, stream, mr); + cur = next; } +} - auto d_in = cudf::column_device_view::create(binary_input, stream); +// Utility function: make_null_mask_from_valid +// (Moved here to be available for repeated message child extraction) +template +inline std::pair 
make_null_mask_from_valid( + rmm::device_uvector const& valid, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto begin = thrust::make_counting_iterator(0); + auto end = begin + valid.size(); + auto pred = [ptr = valid.data()] __device__(cudf::size_type i) { return static_cast(ptr[i]); }; + return cudf::detail::valid_if(begin, end, pred, stream, mr); +} - // Prepare field descriptors for the scanning kernel - std::vector h_field_descs(num_decoded_fields); - for (int i = 0; i < num_decoded_fields; i++) { - int schema_idx = decoded_field_indices[i]; - h_field_descs[i].field_number = field_numbers[i]; - h_field_descs[i].expected_wire_type = - get_expected_wire_type(all_types[schema_idx].id(), encodings[i]); - } +/** + * Scan for child fields within repeated message occurrences. + * Each occurrence is a protobuf message, and we need to find child field locations within it. + */ +__global__ void scan_repeated_message_children_kernel( + uint8_t const* message_data, + int32_t const* msg_row_offsets, // Row offset for each occurrence + field_location const* msg_locs, // Location of each message occurrence (offset within row, length) + int num_occurrences, + field_descriptor const* child_descs, + int num_child_fields, + field_location* child_locs, // Output: [num_occurrences * num_child_fields] + int* error_flag) +{ + auto occ_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (occ_idx >= num_occurrences) return; - rmm::device_uvector d_field_descs(num_decoded_fields, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_descs.data(), - h_field_descs.data(), - num_decoded_fields * sizeof(field_descriptor), - cudaMemcpyHostToDevice, - stream.value())); + // Initialize child locations to not found + for (int f = 0; f < num_child_fields; f++) { + child_locs[occ_idx * num_child_fields + f] = {-1, 0}; + } - // Allocate field locations array: [rows * num_decoded_fields] - rmm::device_uvector d_locations( - static_cast(rows) * 
num_decoded_fields, stream, mr); + auto const& msg_loc = msg_locs[occ_idx]; + if (msg_loc.offset < 0) return; - // Track errors - rmm::device_uvector d_error(1, stream, mr); - CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + // Calculate absolute position of this message in the data + int32_t row_offset = msg_row_offsets[occ_idx]; + uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; + uint8_t const* msg_end = msg_start + msg_loc.length; - // Check if any field has enum validation - bool has_enum_fields = std::any_of( - enum_valid_values.begin(), enum_valid_values.end(), [](auto const& v) { return !v.empty(); }); + uint8_t const* cur = msg_start; - // Track rows with invalid enum values (used to null entire struct row) + while (cur < msg_end) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, msg_end, key, key_bytes)) { + atomicExch(error_flag, 1); + return; + } + cur += key_bytes; + + int fn = static_cast(key >> 3); + int wt = static_cast(key & 0x7); + + if (fn == 0) { + atomicExch(error_flag, 1); + return; + } + + // Check against child field descriptors + for (int f = 0; f < num_child_fields; f++) { + if (child_descs[f].field_number == fn) { + if (wt != child_descs[f].expected_wire_type) { + // Wire type mismatch - could be OK for some cases (e.g., packed vs unpacked) + // For now, just continue + continue; + } + + int data_offset = static_cast(cur - msg_start); + + if (wt == WT_LEN) { + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + atomicExch(error_flag, 1); + return; + } + // Store offset (after length prefix) and length + child_locs[occ_idx * num_child_fields + f] = {data_offset + len_bytes, static_cast(len)}; + } else { + // For varint/fixed types, store offset and estimated length + int32_t data_length = 0; + if (wt == WT_VARINT) { + uint64_t dummy; + int vbytes; + if (read_varint(cur, msg_end, dummy, vbytes)) { + data_length = vbytes; + } + } else if (wt == 
WT_32BIT) { + data_length = 4; + } else if (wt == WT_64BIT) { + data_length = 8; + } + child_locs[occ_idx * num_child_fields + f] = {data_offset, data_length}; + } + // Don't break - last occurrence wins (protobuf semantics) + } + } + + // Skip to next field + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + atomicExch(error_flag, 1); + return; + } + cur = next; + } +} + +/** + * Count repeated field occurrences within nested messages. + * Similar to count_repeated_fields_kernel but operates on nested message locations. + */ +__global__ void count_repeated_in_nested_kernel( + uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + repeated_field_info* repeated_info, + int num_repeated, + int const* repeated_indices, + int* error_flag) +{ + auto row_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row_idx >= num_rows) return; + + // Initialize counts + for (int ri = 0; ri < num_repeated; ri++) { + repeated_info[row_idx * num_repeated + ri] = {0, 0}; + } + + auto const& parent_loc = parent_locs[row_idx]; + if (parent_loc.offset < 0) return; + + cudf::size_type row_off; + row_off = row_offsets[row_idx] - base_offset; + + uint8_t const* msg_start = message_data + row_off + parent_loc.offset; + uint8_t const* msg_end = msg_start + parent_loc.length; + uint8_t const* cur = msg_start; + + while (cur < msg_end) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, msg_end, key, key_bytes)) { + atomicExch(error_flag, 1); + return; + } + cur += key_bytes; + + int fn = static_cast(key >> 3); + int wt = static_cast(key & 0x7); + + // Check if this is one of our repeated fields + for (int ri = 0; ri < num_repeated; ri++) { + int schema_idx = repeated_indices[ri]; + if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { + int data_len = 0; + if (wt == 
WT_LEN) { + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + atomicExch(error_flag, 1); + return; + } + data_len = static_cast(len); + } + repeated_info[row_idx * num_repeated + ri].count++; + repeated_info[row_idx * num_repeated + ri].total_length += data_len; + } + } + + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + atomicExch(error_flag, 1); + return; + } + cur = next; + } +} + +/** + * Scan for repeated field occurrences within nested messages. + */ +__global__ void scan_repeated_in_nested_kernel( + uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + repeated_field_info const* repeated_info, + int num_repeated, + int const* repeated_indices, + repeated_occurrence* occurrences, + int* error_flag) +{ + auto row_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row_idx >= num_rows) return; + + auto const& parent_loc = parent_locs[row_idx]; + if (parent_loc.offset < 0) return; + + // Calculate output offset for this row + int occ_offset = 0; + for (int r = 0; r < row_idx; r++) { + occ_offset += repeated_info[r * num_repeated].count; + } + + cudf::size_type row_off = row_offsets[row_idx] - base_offset; + + uint8_t const* msg_start = message_data + row_off + parent_loc.offset; + uint8_t const* msg_end = msg_start + parent_loc.length; + uint8_t const* cur = msg_start; + + int occ_idx = 0; + + while (cur < msg_end) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, msg_end, key, key_bytes)) { + atomicExch(error_flag, 1); + return; + } + cur += key_bytes; + + int fn = static_cast(key >> 3); + int wt = static_cast(key & 0x7); + + // Check if this is our repeated field (assuming single repeated field for simplicity) + int schema_idx = repeated_indices[0]; + if (schema[schema_idx].field_number == fn && 
schema[schema_idx].is_repeated) { + int32_t data_offset = static_cast(cur - msg_start); + int32_t data_len = 0; + + if (wt == WT_LEN) { + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + atomicExch(error_flag, 1); + return; + } + data_offset += len_bytes; + data_len = static_cast(len); + } else if (wt == WT_VARINT) { + uint64_t dummy; + int vbytes; + if (read_varint(cur, msg_end, dummy, vbytes)) { + data_len = vbytes; + } + } else if (wt == WT_32BIT) { + data_len = 4; + } else if (wt == WT_64BIT) { + data_len = 8; + } + + occurrences[occ_offset + occ_idx] = {row_idx, data_offset, data_len}; + occ_idx++; + } + + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + atomicExch(error_flag, 1); + return; + } + cur = next; + } +} + +/** + * Extract varint values from repeated field occurrences within nested messages. + */ +template +__global__ void extract_repeated_in_nested_varint_kernel( + uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + repeated_occurrence const* occurrences, + int total_count, + OutT* out, + int* error_flag) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_count) return; + + auto const& occ = occurrences[idx]; + auto const& parent_loc = parent_locs[occ.row_idx]; + + cudf::size_type row_off = row_offsets[occ.row_idx] - base_offset; + uint8_t const* data_ptr = message_data + row_off + parent_loc.offset + occ.offset; + + uint64_t val; + int vbytes; + if (!read_varint(data_ptr, data_ptr + 10, val, vbytes)) { + atomicExch(error_flag, 1); + return; + } + + if constexpr (ZigZag) { + val = (val >> 1) ^ (~(val & 1) + 1); + } + + out[idx] = static_cast(val); +} + +/** + * Extract string values from repeated field occurrences within nested messages. 
+ */ +__global__ void extract_repeated_in_nested_string_kernel( + uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + repeated_occurrence const* occurrences, + int total_count, + int32_t const* str_offsets, + char* chars, + int* error_flag) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_count) return; + + auto const& occ = occurrences[idx]; + auto const& parent_loc = parent_locs[occ.row_idx]; + + cudf::size_type row_off = row_offsets[occ.row_idx] - base_offset; + uint8_t const* data_ptr = message_data + row_off + parent_loc.offset + occ.offset; + + int32_t out_offset = str_offsets[idx]; + for (int32_t i = 0; i < occ.length; i++) { + chars[out_offset + i] = static_cast(data_ptr[i]); + } +} + +/** + * Extract varint child fields from repeated message occurrences. + */ +template +__global__ void extract_repeated_msg_child_varint_kernel( + uint8_t const* message_data, + int32_t const* msg_row_offsets, + field_location const* msg_locs, + field_location const* child_locs, + int child_idx, + int num_child_fields, + OutT* out, + bool* valid, + int num_occurrences, + int* error_flag, + bool has_default = false, + int64_t default_value = 0) +{ + auto occ_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (occ_idx >= num_occurrences) return; + + auto const& msg_loc = msg_locs[occ_idx]; + auto const& field_loc = child_locs[occ_idx * num_child_fields + child_idx]; + + if (msg_loc.offset < 0 || field_loc.offset < 0) { + if (has_default) { + out[occ_idx] = static_cast(default_value); + valid[occ_idx] = true; + } else { + valid[occ_idx] = false; + } + return; + } + + int32_t row_offset = msg_row_offsets[occ_idx]; + uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; + uint8_t const* cur = msg_start + field_loc.offset; + + uint64_t val; + int vbytes; + if (!read_varint(cur, cur + 10, val, vbytes)) { + atomicExch(error_flag, 1); 
+ valid[occ_idx] = false; + return; + } + + if constexpr (ZigZag) { + val = (val >> 1) ^ (~(val & 1) + 1); + } + + out[occ_idx] = static_cast(val); + valid[occ_idx] = true; +} + +/** + * Extract fixed-size child fields from repeated message occurrences. + */ +template +__global__ void extract_repeated_msg_child_fixed_kernel( + uint8_t const* message_data, + int32_t const* msg_row_offsets, + field_location const* msg_locs, + field_location const* child_locs, + int child_idx, + int num_child_fields, + OutT* out, + bool* valid, + int num_occurrences, + int* error_flag, + bool has_default = false, + OutT default_value = OutT{}) +{ + auto occ_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (occ_idx >= num_occurrences) return; + + auto const& msg_loc = msg_locs[occ_idx]; + auto const& field_loc = child_locs[occ_idx * num_child_fields + child_idx]; + + if (msg_loc.offset < 0 || field_loc.offset < 0) { + if (has_default) { + out[occ_idx] = default_value; + valid[occ_idx] = true; + } else { + valid[occ_idx] = false; + } + return; + } + + int32_t row_offset = msg_row_offsets[occ_idx]; + uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; + uint8_t const* cur = msg_start + field_loc.offset; + + OutT value; + if constexpr (WT == WT_32BIT) { + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { + uint64_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } + + out[occ_idx] = value; + valid[occ_idx] = true; +} + +/** + * Helper to build string column for repeated message child fields. 
+ */ +inline std::unique_ptr build_repeated_msg_child_string_column( + uint8_t const* message_data, + rmm::device_uvector const& d_msg_row_offsets, + rmm::device_uvector const& d_msg_locs, + rmm::device_uvector const& d_child_locs, + int child_idx, + int num_child_fields, + int total_count, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (total_count == 0) { + return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); + } + + // Get string lengths from child_locs + std::vector h_child_locs(total_count * num_child_fields); + CUDF_CUDA_TRY(cudaMemcpyAsync(h_child_locs.data(), d_child_locs.data(), + h_child_locs.size() * sizeof(field_location), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + std::vector h_lengths(total_count); + int32_t total_chars = 0; + for (int i = 0; i < total_count; i++) { + auto const& loc = h_child_locs[i * num_child_fields + child_idx]; + if (loc.offset >= 0) { + h_lengths[i] = loc.length; + total_chars += loc.length; + } else { + h_lengths[i] = 0; + } + } + + // Build string offsets + rmm::device_uvector str_offsets(total_count + 1, stream, mr); + std::vector h_offsets(total_count + 1); + h_offsets[0] = 0; + for (int i = 0; i < total_count; i++) { + h_offsets[i + 1] = h_offsets[i] + h_lengths[i]; + } + CUDF_CUDA_TRY(cudaMemcpyAsync(str_offsets.data(), h_offsets.data(), + (total_count + 1) * sizeof(int32_t), + cudaMemcpyHostToDevice, stream.value())); + + // Copy string data + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + std::vector h_msg_locs(total_count); + std::vector h_row_offsets(total_count); + CUDF_CUDA_TRY(cudaMemcpyAsync(h_msg_locs.data(), d_msg_locs.data(), + total_count * sizeof(field_location), + cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(h_row_offsets.data(), d_msg_row_offsets.data(), + total_count * sizeof(int32_t), + cudaMemcpyDeviceToHost, stream.value())); + 
stream.synchronize(); + + // Copy each string on host (not ideal but works) + std::vector h_chars(total_chars); + int char_idx = 0; + for (int i = 0; i < total_count; i++) { + auto const& field_loc = h_child_locs[i * num_child_fields + child_idx]; + if (field_loc.offset >= 0 && field_loc.length > 0) { + int32_t row_offset = h_row_offsets[i]; + int32_t msg_offset = h_msg_locs[i].offset; + uint8_t const* str_ptr = message_data + row_offset + msg_offset + field_loc.offset; + // Need to copy from device - use cudaMemcpy + CUDF_CUDA_TRY(cudaMemcpy(h_chars.data() + char_idx, str_ptr, + field_loc.length, cudaMemcpyDeviceToHost)); + char_idx += field_loc.length; + } + } + CUDF_CUDA_TRY(cudaMemcpyAsync(chars.data(), h_chars.data(), + total_chars, cudaMemcpyHostToDevice, stream.value())); + } + + // Build validity mask + rmm::device_uvector valid(total_count, stream, mr); + std::vector h_valid(total_count); + for (int i = 0; i < total_count; i++) { + h_valid[i] = (h_child_locs[i * num_child_fields + child_idx].offset >= 0) ? 1 : 0; + } + rmm::device_uvector d_valid_u8(total_count, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_u8.data(), h_valid.data(), + total_count * sizeof(uint8_t), + cudaMemcpyHostToDevice, stream.value())); + + auto [mask, null_count] = make_null_mask_from_valid(d_valid_u8, stream, mr); + + auto str_offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, total_count + 1, str_offsets.release(), rmm::device_buffer{}, 0); + return cudf::make_strings_column(total_count, std::move(str_offsets_col), chars.release(), null_count, std::move(mask)); +} + +/** + * Extract varint from nested message locations. 
+ */ +template +__global__ void extract_nested_varint_kernel( + uint8_t const* message_data, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + field_location const* field_locations, + int field_idx, + int num_fields, + OutT* out, + bool* valid, + int num_rows, + int* error_flag, + bool has_default = false, + int64_t default_value = 0) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + auto const& parent_loc = parent_locations[row]; + auto const& field_loc = field_locations[row * num_fields + field_idx]; + + if (parent_loc.offset < 0 || field_loc.offset < 0) { + if (has_default) { + out[row] = static_cast(default_value); + valid[row] = true; + } else { + valid[row] = false; + } + return; + } + + auto parent_row_start = parent_row_offsets[row] - parent_base_offset; + uint8_t const* cur = message_data + parent_row_start + parent_loc.offset + field_loc.offset; + uint8_t const* cur_end = cur + field_loc.length; + + uint64_t v; + int n; + if (!read_varint(cur, cur_end, v, n)) { + atomicExch(error_flag, 1); + valid[row] = false; + return; + } + + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } + out[row] = static_cast(v); + valid[row] = true; +} + +/** + * Extract fixed-size from nested message locations. 
+ */ +template +__global__ void extract_nested_fixed_kernel( + uint8_t const* message_data, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + field_location const* field_locations, + int field_idx, + int num_fields, + OutT* out, + bool* valid, + int num_rows, + int* error_flag, + bool has_default = false, + OutT default_value = OutT{}) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + auto const& parent_loc = parent_locations[row]; + auto const& field_loc = field_locations[row * num_fields + field_idx]; + + if (parent_loc.offset < 0 || field_loc.offset < 0) { + if (has_default) { + out[row] = default_value; + valid[row] = true; + } else { + valid[row] = false; + } + return; + } + + auto parent_row_start = parent_row_offsets[row] - parent_base_offset; + uint8_t const* cur = message_data + parent_row_start + parent_loc.offset + field_loc.offset; + + OutT value; + if constexpr (WT == WT_32BIT) { + if (field_loc.length < 4) { + atomicExch(error_flag, 1); + valid[row] = false; + return; + } + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { + if (field_loc.length < 8) { + atomicExch(error_flag, 1); + valid[row] = false; + return; + } + uint64_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } + + out[row] = value; + valid[row] = true; +} + +/** + * Copy nested variable-length data (string/bytes). 
+ */ +__global__ void copy_nested_varlen_data_kernel( + uint8_t const* message_data, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + field_location const* field_locations, + int field_idx, + int num_fields, + int32_t const* output_offsets, + char* output_data, + int num_rows, + bool has_default = false, + uint8_t const* default_data = nullptr, + int32_t default_length = 0) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + auto const& parent_loc = parent_locations[row]; + auto const& field_loc = field_locations[row * num_fields + field_idx]; + + char* dst = output_data + output_offsets[row]; + + if (parent_loc.offset < 0 || field_loc.offset < 0) { + if (has_default && default_length > 0) { + for (int i = 0; i < default_length; i++) { + dst[i] = static_cast(default_data[i]); + } + } + return; + } + + if (field_loc.length == 0) return; + + auto parent_row_start = parent_row_offsets[row] - parent_base_offset; + uint8_t const* src = message_data + parent_row_start + parent_loc.offset + field_loc.offset; + + for (int i = 0; i < field_loc.length; i++) { + dst[i] = static_cast(src[i]); + } +} + +/** + * Extract nested field lengths for prefix sum. 
+ */ +__global__ void extract_nested_lengths_kernel( + field_location const* parent_locations, + field_location const* field_locations, + int field_idx, + int num_fields, + int32_t* lengths, + int num_rows, + bool has_default = false, + int32_t default_length = 0) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + auto const& parent_loc = parent_locations[row]; + auto const& field_loc = field_locations[row * num_fields + field_idx]; + + if (parent_loc.offset >= 0 && field_loc.offset >= 0) { + lengths[row] = field_loc.length; + } else if (has_default) { + lengths[row] = default_length; + } else { + lengths[row] = 0; + } +} + +/** + * Extract scalar string field lengths for prefix sum. + * For top-level STRING fields (not nested within a struct). + */ +__global__ void extract_scalar_string_lengths_kernel( + field_location const* field_locations, + int field_idx, + int num_fields, + int32_t* lengths, + int num_rows, + bool has_default = false, + int32_t default_length = 0) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + auto const& loc = field_locations[row * num_fields + field_idx]; + + if (loc.offset >= 0) { + lengths[row] = loc.length; + } else if (has_default) { + lengths[row] = default_length; + } else { + lengths[row] = 0; + } +} + +/** + * Copy scalar string field data. + * For top-level STRING fields (not nested within a struct). 
+ */ +__global__ void copy_scalar_string_data_kernel( + uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type row_base_offset, + field_location const* field_locations, + int field_idx, + int num_fields, + int32_t const* output_offsets, + char* output_data, + int num_rows, + bool has_default = false, + uint8_t const* default_data = nullptr, + int32_t default_length = 0) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + auto const& loc = field_locations[row * num_fields + field_idx]; + + char* dst = output_data + output_offsets[row]; + + if (loc.offset < 0) { + // Field not found - use default if available + if (has_default && default_length > 0) { + for (int i = 0; i < default_length; i++) { + dst[i] = static_cast(default_data[i]); + } + } + return; + } + + if (loc.length == 0) return; + + auto row_start = row_offsets[row] - row_base_offset; + uint8_t const* src = message_data + row_start + loc.offset; + + for (int i = 0; i < loc.length; i++) { + dst[i] = static_cast(src[i]); + } +} + +// ============================================================================ +// Utility functions +// ============================================================================ + +// Note: make_null_mask_from_valid is defined earlier in the file (before scan_repeated_message_children_kernel) + +/** + * Get the expected wire type for a given cudf type and encoding. + */ +int get_expected_wire_type(cudf::type_id type_id, int encoding) +{ + switch (type_id) { + case cudf::type_id::BOOL8: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + if (encoding == spark_rapids_jni::ENC_FIXED) { + return (type_id == cudf::type_id::INT32 || type_id == cudf::type_id::UINT32) ? 
WT_32BIT + : WT_64BIT; + } + return WT_VARINT; + case cudf::type_id::FLOAT32: return WT_32BIT; + case cudf::type_id::FLOAT64: return WT_64BIT; + case cudf::type_id::STRING: + case cudf::type_id::LIST: return WT_LEN; + default: CUDF_FAIL("Unsupported type for protobuf decoding"); + } +} + +/** + * Create an all-null column of the specified type. + */ +std::unique_ptr make_null_column(cudf::data_type dtype, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (num_rows == 0) { return cudf::make_empty_column(dtype); } + + switch (dtype.id()) { + case cudf::type_id::BOOL8: + case cudf::type_id::INT8: + case cudf::type_id::UINT8: + case cudf::type_id::INT16: + case cudf::type_id::UINT16: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT32: + case cudf::type_id::FLOAT64: { + auto data = rmm::device_buffer(cudf::size_of(dtype) * num_rows, stream, mr); + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + return std::make_unique( + dtype, num_rows, std::move(data), std::move(null_mask), num_rows); + } + case cudf::type_id::STRING: { + // Create empty strings column with all nulls + rmm::device_uvector pairs(num_rows, stream, mr); + thrust::fill(rmm::exec_policy(stream), + pairs.begin(), + pairs.end(), + cudf::strings::detail::string_index_pair{nullptr, 0}); + return cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr); + } + case cudf::type_id::LIST: { + // Create LIST with all nulls + // Offsets: all zeros (empty lists) + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + + // Empty child column - use INT8 as default element type + // 
This works because the list has 0 elements, so the child type doesn't matter for nulls + auto child_col = std::make_unique( + cudf::data_type{cudf::type_id::INT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); + + // All null mask + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + num_rows, + std::move(null_mask), + stream, + mr); + } + case cudf::type_id::STRUCT: { + // Create STRUCT with all nulls and no children + // Note: This is a workaround. Proper nested struct handling requires recursive processing + // with full schema information. An empty struct with no children won't match expected + // schema for deeply nested types, but prevents crashes for unprocessed struct fields. + std::vector> empty_children; + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + return cudf::make_structs_column( + num_rows, std::move(empty_children), num_rows, std::move(null_mask), stream, mr); + } + default: CUDF_FAIL("Unsupported type for null column creation"); + } +} + +/** + * Create an empty column (0 rows) of the specified type. + * This handles nested types (LIST, STRUCT) that cudf::make_empty_column doesn't support. 
+ */ +std::unique_ptr make_empty_column_safe(cudf::data_type dtype, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + switch (dtype.id()) { + case cudf::type_id::LIST: { + // Create empty list column with empty UINT8 child (Spark BinaryType maps to LIST) + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, 1, rmm::device_buffer(sizeof(int32_t), stream, mr), + rmm::device_buffer{}, 0); + // Initialize offset to 0 + int32_t zero = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), &zero, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + auto child_col = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); + return cudf::make_lists_column( + 0, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + } + case cudf::type_id::STRUCT: { + // Create empty struct column with no children + std::vector> empty_children; + return cudf::make_structs_column( + 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); + } + default: + // For non-nested types, use cudf's make_empty_column + return cudf::make_empty_column(dtype); + } +} + +/** + * Find all child field indices for a given parent index in the schema. + * This is a commonly used pattern throughout the codebase. 
+ * + * @param schema The schema vector (either nested_field_descriptor or device_nested_field_descriptor) + * @param num_fields Number of fields in the schema + * @param parent_idx The parent index to search for + * @return Vector of child field indices + */ +template +std::vector find_child_field_indices( + SchemaT const& schema, + int num_fields, + int parent_idx) +{ + std::vector child_indices; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == parent_idx) { + child_indices.push_back(i); + } + } + return child_indices; +} + +/** + * Recursively create an empty struct column with proper nested structure based on schema. + * This handles STRUCT children that contain their own grandchildren. + * + * @param schema The schema vector + * @param schema_output_types Output types for each schema field + * @param parent_idx Index of the parent field (whose children we want to create) + * @param num_fields Total number of fields in schema + * @param stream CUDA stream + * @param mr Memory resource + * @return Empty struct column with proper nested structure + */ +template +std::unique_ptr make_empty_struct_column_with_schema( + SchemaT const& schema, + std::vector const& schema_output_types, + int parent_idx, + int num_fields, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto child_indices = find_child_field_indices(schema, num_fields, parent_idx); + + std::vector> children; + for (int child_idx : child_indices) { + auto child_type = schema_output_types[child_idx]; + + // Recursively handle nested struct children + if (child_type.id() == cudf::type_id::STRUCT) { + children.push_back(make_empty_struct_column_with_schema( + schema, schema_output_types, child_idx, num_fields, stream, mr)); + } else { + children.push_back(make_empty_column_safe(child_type, stream, mr)); + } + } + + return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer{}, stream, mr); +} + +} // namespace + +// 
============================================================================ +// Kernel to check required fields after scan pass +// ============================================================================ + +/** + * Check if any required fields are missing (offset < 0) and set error flag. + * This is called after the scan pass to validate required field constraints. + */ +__global__ void check_required_fields_kernel( + field_location const* locations, // [num_rows * num_fields] + uint8_t const* is_required, // [num_fields] (1 = required, 0 = optional) + int num_fields, + int num_rows, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + for (int f = 0; f < num_fields; f++) { + if (is_required[f] != 0 && locations[row * num_fields + f].offset < 0) { + // Required field is missing - set error flag + atomicExch(error_flag, 1); + return; // No need to check other fields for this row + } + } +} + +/** + * Validate enum values against a set of valid values. + * If a value is not in the valid set: + * 1. Mark the field as invalid (valid[row] = false) + * 2. Mark the row as having an invalid enum (row_has_invalid_enum[row] = true) + * + * This matches Spark CPU PERMISSIVE mode behavior: when an unknown enum value is + * encountered, the entire struct row is set to null (not just the enum field). + * + * The valid_values array must be sorted for binary search. 
+ */ +__global__ void validate_enum_values_kernel( + int32_t const* values, // [num_rows] extracted enum values + bool* valid, // [num_rows] field validity flags (will be modified) + bool* row_has_invalid_enum, // [num_rows] row-level invalid enum flag (will be set to true) + int32_t const* valid_enum_values, // sorted array of valid enum values + int num_valid_values, // size of valid_enum_values + int num_rows) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + // Skip if already invalid (field was missing) - missing field is not an enum error + if (!valid[row]) return; + + int32_t val = values[row]; + + // Binary search for the value in valid_enum_values + int left = 0; + int right = num_valid_values - 1; + bool found = false; + + while (left <= right) { + int mid = left + (right - left) / 2; + if (valid_enum_values[mid] == val) { + found = true; + break; + } else if (valid_enum_values[mid] < val) { + left = mid + 1; + } else { + right = mid - 1; + } + } + + // If not found, mark as invalid + if (!found) { + valid[row] = false; + // Also mark the row as having an invalid enum - this will null the entire struct row + row_has_invalid_enum[row] = true; + } +} + +namespace spark_rapids_jni { + +std::unique_ptr decode_protobuf_to_struct( + cudf::column_view const& binary_input, + int total_num_fields, + std::vector const& decoded_field_indices, + std::vector const& field_numbers, + std::vector const& all_types, + std::vector const& encodings, + std::vector const& is_required, + std::vector const& has_default_value, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + bool fail_on_errors) +{ + CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, + "binary_input must be a LIST column"); + cudf::lists_column_view const in_list(binary_input); + auto const child_type = 
in_list.child().type().id(); + CUDF_EXPECTS(child_type == cudf::type_id::INT8 || child_type == cudf::type_id::UINT8, + "binary_input must be a LIST column"); + CUDF_EXPECTS(static_cast(all_types.size()) == total_num_fields, + "all_types size must equal total_num_fields"); + CUDF_EXPECTS(decoded_field_indices.size() == field_numbers.size(), + "decoded_field_indices and field_numbers must have the same length"); + CUDF_EXPECTS(encodings.size() == field_numbers.size(), + "encodings and field_numbers must have the same length"); + CUDF_EXPECTS(is_required.size() == field_numbers.size(), + "is_required and field_numbers must have the same length"); + CUDF_EXPECTS(has_default_value.size() == field_numbers.size(), + "has_default_value and field_numbers must have the same length"); + CUDF_EXPECTS(default_ints.size() == field_numbers.size(), + "default_ints and field_numbers must have the same length"); + CUDF_EXPECTS(default_floats.size() == field_numbers.size(), + "default_floats and field_numbers must have the same length"); + CUDF_EXPECTS(default_bools.size() == field_numbers.size(), + "default_bools and field_numbers must have the same length"); + CUDF_EXPECTS(default_strings.size() == field_numbers.size(), + "default_strings and field_numbers must have the same length"); + + auto const stream = cudf::get_default_stream(); + auto mr = cudf::get_current_device_resource_ref(); + auto rows = binary_input.size(); + auto num_decoded_fields = static_cast(field_numbers.size()); + + // Handle zero-row case + if (rows == 0) { + std::vector> empty_children; + empty_children.reserve(total_num_fields); + for (auto const& dt : all_types) { + empty_children.push_back(make_empty_column_safe(dt, stream, mr)); + } + return cudf::make_structs_column( + 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); + } + + // Handle case with no fields to decode + if (num_decoded_fields == 0) { + std::vector> null_children; + null_children.reserve(total_num_fields); + for (auto 
const& dt : all_types) { + null_children.push_back(make_null_column(dt, rows, stream, mr)); + } + return cudf::make_structs_column( + rows, std::move(null_children), 0, rmm::device_buffer{}, stream, mr); + } + + auto d_in = cudf::column_device_view::create(binary_input, stream); + + // Prepare field descriptors for the scanning kernel + std::vector h_field_descs(num_decoded_fields); + for (int i = 0; i < num_decoded_fields; i++) { + int schema_idx = decoded_field_indices[i]; + h_field_descs[i].field_number = field_numbers[i]; + h_field_descs[i].expected_wire_type = + get_expected_wire_type(all_types[schema_idx].id(), encodings[i]); + } + + rmm::device_uvector d_field_descs(num_decoded_fields, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_descs.data(), + h_field_descs.data(), + num_decoded_fields * sizeof(field_descriptor), + cudaMemcpyHostToDevice, + stream.value())); + + // Allocate field locations array: [rows * num_decoded_fields] + rmm::device_uvector d_locations( + static_cast(rows) * num_decoded_fields, stream, mr); + + // Track errors + rmm::device_uvector d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + + // Check if any field has enum validation + bool has_enum_fields = std::any_of( + enum_valid_values.begin(), enum_valid_values.end(), [](auto const& v) { return !v.empty(); }); + + // Track rows with invalid enum values (used to null entire struct row) // This matches Spark CPU PERMISSIVE mode behavior rmm::device_uvector d_row_has_invalid_enum(has_enum_fields ? 
rows : 0, stream, mr); if (has_enum_fields) { @@ -752,599 +2246,2208 @@ std::unique_ptr decode_protobuf_to_struct( } auto const threads = 256; - auto const blocks = static_cast((rows + threads - 1) / threads); + auto const blocks = static_cast((rows + threads - 1) / threads); + + // ========================================================================= + // Pass 1: Scan all messages and record field locations + // ========================================================================= + scan_all_fields_kernel<<>>( + *d_in, d_field_descs.data(), num_decoded_fields, d_locations.data(), d_error.data()); + + // ========================================================================= + // Check required fields (after scan pass) + // ========================================================================= + // Only check if any field is required to avoid unnecessary kernel launch + bool has_required_fields = + std::any_of(is_required.begin(), is_required.end(), [](bool b) { return b; }); + if (has_required_fields) { + // Copy is_required flags to device + // Note: std::vector is special (bitfield), so we convert to uint8_t + rmm::device_uvector d_is_required(num_decoded_fields, stream, mr); + std::vector h_is_required_vec(num_decoded_fields); + for (int i = 0; i < num_decoded_fields; i++) { + h_is_required_vec[i] = is_required[i] ? 
1 : 0; + } + CUDF_CUDA_TRY(cudaMemcpyAsync(d_is_required.data(), + h_is_required_vec.data(), + num_decoded_fields * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream.value())); + + check_required_fields_kernel<<>>( + d_locations.data(), d_is_required.data(), num_decoded_fields, rows, d_error.data()); + } + + // Get message data pointer and offsets for pass 2 + auto const* message_data = reinterpret_cast(in_list.child().data()); + auto const* list_offsets = in_list.offsets().data(); + // Get the base offset by copying from device to host + cudf::size_type base_offset = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync( + &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + // ========================================================================= + // Pass 2: Extract data for each field + // ========================================================================= + std::vector> all_children(total_num_fields); + int decoded_idx = 0; + + for (int schema_idx = 0; schema_idx < total_num_fields; schema_idx++) { + if (decoded_idx < num_decoded_fields && decoded_field_indices[decoded_idx] == schema_idx) { + // This field needs to be decoded + auto const dt = all_types[schema_idx]; + auto const enc = encodings[decoded_idx]; + + switch (dt.id()) { + case cudf::type_id::BOOL8: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + int64_t def_val = has_def ? (default_bools[decoded_idx] ? 
1 : 0) : 0; + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_val); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; + } + + case cudf::type_id::INT32: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + int64_t def_int = has_def ? default_ints[decoded_idx] : 0; + int32_t def_fixed = static_cast(def_int); + if (enc == spark_rapids_jni::ENC_ZIGZAG) { + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); + } else if (enc == spark_rapids_jni::ENC_FIXED) { + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_fixed); + } else { + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); + } + + // Validate enum values if this is an enum field + // enum_valid_values[decoded_idx] is non-empty for enum fields + auto const& valid_enums = enum_valid_values[decoded_idx]; + if (!valid_enums.empty()) { + // Copy valid enum values to device + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + // 
Validate enum values - unknown values will null the entire row + validate_enum_values_kernel<<>>( + out.data(), + valid.data(), + d_row_has_invalid_enum.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + rows); + } + + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; + } + + case cudf::type_id::UINT32: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + int64_t def_int = has_def ? default_ints[decoded_idx] : 0; + uint32_t def_fixed = static_cast(def_int); + if (enc == spark_rapids_jni::ENC_FIXED) { + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_fixed); + } else { + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; + } + + case cudf::type_id::INT64: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + int64_t def_int = has_def ? 
default_ints[decoded_idx] : 0; + if (enc == spark_rapids_jni::ENC_ZIGZAG) { + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); + } else if (enc == spark_rapids_jni::ENC_FIXED) { + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); + } else { + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; + } + + case cudf::type_id::UINT64: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + int64_t def_int = has_def ? 
default_ints[decoded_idx] : 0; + uint64_t def_fixed = static_cast(def_int); + if (enc == spark_rapids_jni::ENC_FIXED) { + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_fixed); + } else { + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; + } + + case cudf::type_id::FLOAT32: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + float def_float = has_def ? static_cast(default_floats[decoded_idx]) : 0.0f; + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_float); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; + } + + case cudf::type_id::FLOAT64: { + rmm::device_uvector out(rows, stream, mr); + rmm::device_uvector valid(rows, stream, mr); + bool has_def = has_default_value[decoded_idx]; + double def_double = has_def ? 
default_floats[decoded_idx] : 0.0; + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + out.data(), + valid.data(), + rows, + d_error.data(), + has_def, + def_double); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + all_children[schema_idx] = + std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + break; + } + + case cudf::type_id::STRING: { + // Check for default value + bool has_def = has_default_value[decoded_idx]; + auto const& def_str = default_strings[decoded_idx]; + int32_t def_len = has_def ? static_cast(def_str.size()) : 0; + + // Copy default string to device if needed + rmm::device_uvector d_default_str(def_len, stream, mr); + if (has_def && def_len > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), + def_str.data(), + def_len, + cudaMemcpyHostToDevice, + stream.value())); + } + + // Extract lengths and compute output offsets via prefix sum + rmm::device_uvector lengths(rows, stream, mr); + extract_lengths_kernel<<>>(d_locations.data(), + decoded_idx, + num_decoded_fields, + lengths.data(), + rows, + has_def, + def_len); + + rmm::device_uvector output_offsets(rows + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); + + // Get total size + int32_t total_chars = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, + output_offsets.data() + rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + lengths.data() + rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + stream.synchronize(); + total_chars += last_len; + + // Set the final offset + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + rows, + &total_chars, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + // Allocate and copy character data 
+ rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + copy_varlen_data_kernel<<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + output_offsets.data(), + chars.data(), + rows, + has_def, + d_default_str.data(), + def_len); + } + + // Create validity mask (field found OR has default = valid) + rmm::device_uvector valid(rows, stream, mr); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(rows), + valid.begin(), + [locs = d_locations.data(), decoded_idx, num_decoded_fields, has_def] __device__( + auto row) { + return locs[row * num_decoded_fields + decoded_idx].offset >= 0 || has_def; + }); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + + // Create offsets column + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); + + // Create strings column using offsets + chars buffer + all_children[schema_idx] = cudf::make_strings_column( + rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); + break; + } + + case cudf::type_id::LIST: { + // For protobuf bytes: create LIST directly (optimization #2) + // Check for default value + bool has_def = has_default_value[decoded_idx]; + auto const& def_bytes = default_strings[decoded_idx]; + int32_t def_len = has_def ? 
static_cast(def_bytes.size()) : 0; + + // Copy default bytes to device if needed + rmm::device_uvector d_default_bytes(def_len, stream, mr); + if (has_def && def_len > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_bytes.data(), + def_bytes.data(), + def_len, + cudaMemcpyHostToDevice, + stream.value())); + } + + // Extract lengths and compute output offsets via prefix sum + rmm::device_uvector lengths(rows, stream, mr); + extract_lengths_kernel<<>>(d_locations.data(), + decoded_idx, + num_decoded_fields, + lengths.data(), + rows, + has_def, + def_len); + + rmm::device_uvector output_offsets(rows + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); + + // Get total size + int32_t total_bytes = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, + output_offsets.data() + rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + lengths.data() + rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + stream.synchronize(); + total_bytes += last_len; + + // Set the final offset + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + rows, + &total_bytes, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + // Allocate and copy byte data directly to INT8 buffer + rmm::device_uvector child_data(total_bytes, stream, mr); + if (total_bytes > 0) { + copy_varlen_data_kernel<<>>( + message_data, + list_offsets, + base_offset, + d_locations.data(), + decoded_idx, + num_decoded_fields, + output_offsets.data(), + reinterpret_cast(child_data.data()), + rows, + has_def, + d_default_bytes.data(), + def_len); + } + + // Create validity mask (field found OR has default = valid) + rmm::device_uvector valid(rows, stream, mr); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(rows), + valid.begin(), + [locs = 
d_locations.data(), decoded_idx, num_decoded_fields, has_def] __device__( + auto row) { + return locs[row * num_decoded_fields + decoded_idx].offset >= 0 || has_def; + }); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + + // Create offsets column + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); + + // Create INT8 child column directly (no intermediate strings column!) + auto child_col = std::make_unique(cudf::data_type{cudf::type_id::INT8}, + total_bytes, + child_data.release(), + rmm::device_buffer{}, + 0); + + all_children[schema_idx] = cudf::make_lists_column(rows, + std::move(offsets_col), + std::move(child_col), + null_count, + std::move(mask), + stream, + mr); + break; + } + + default: CUDF_FAIL("Unsupported output type for protobuf decoder"); + } + + decoded_idx++; + } else { + // This field is not decoded - create null column + all_children[schema_idx] = make_null_column(all_types[schema_idx], rows, stream, mr); + } + } + + // Check for errors + CUDF_CUDA_TRY(cudaPeekAtLastError()); + + // Check for any parse errors or missing required fields. + // Note: We check errors after all kernels complete rather than between kernel launches + // to avoid expensive synchronization overhead. If fail_on_errors is true and an error + // occurred, all kernels will have executed but we throw an exception here. 
+ int h_error = 0; + CUDF_CUDA_TRY( + cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + if (fail_on_errors) { + CUDF_EXPECTS(h_error == 0, + "Malformed protobuf message, unsupported wire type, or missing required field"); + } + + // Build the final struct + // If any rows have invalid enum values, create a null mask for the struct + // This matches Spark CPU PERMISSIVE mode: unknown enum values null the entire row + cudf::size_type struct_null_count = 0; + rmm::device_buffer struct_mask{0, stream, mr}; + + if (has_enum_fields) { + // Create struct null mask: row is valid if it has NO invalid enums + auto [mask, null_count] = cudf::detail::valid_if( + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(rows), + [row_invalid = d_row_has_invalid_enum.data()] __device__(cudf::size_type row) { + return !row_invalid[row]; // valid if NOT invalid + }, + stream, + mr); + struct_mask = std::move(mask); + struct_null_count = null_count; + } + + return cudf::make_structs_column( + rows, std::move(all_children), struct_null_count, std::move(struct_mask), stream, mr); +} + +// ============================================================================ +// Nested protobuf decoding implementation +// ============================================================================ + +namespace { + +/** + * Helper to build a repeated scalar column (LIST of scalar type). 
+ */ +template +std::unique_ptr build_repeated_scalar_column( + cudf::column_view const& binary_input, + device_nested_field_descriptor const& field_desc, + std::vector const& h_repeated_info, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + // Get input column's null mask to determine which output rows should be null + // Only rows where INPUT is null should produce null output + // Rows with valid input but count=0 should produce empty array [] + cudf::lists_column_view const in_list(binary_input); + auto const input_null_count = binary_input.null_count(); + + if (total_count == 0) { + // All rows have count=0, but we still need to check input nulls + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), rmm::device_buffer{}, 0); + auto elem_type = field_desc.output_type_id == static_cast(cudf::type_id::LIST) + ? 
cudf::type_id::UINT8 + : static_cast(field_desc.output_type_id); + auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); + + if (input_null_count > 0) { + // Copy input null mask - only input nulls produce output nulls + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), + input_null_count, std::move(null_mask), stream, mr); + } else { + // No input nulls, all rows get empty arrays [] + return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), + 0, rmm::device_buffer{}, stream, mr); + } + } + + auto const* message_data = reinterpret_cast(in_list.child().data()); + auto const* list_offsets = in_list.offsets().data(); + + cudf::size_type base_offset = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&base_offset, list_offsets, sizeof(cudf::size_type), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + // Build list offsets from counts + rmm::device_uvector counts(num_rows, stream, mr); + std::vector h_counts(num_rows); + for (int i = 0; i < num_rows; i++) { + h_counts[i] = h_repeated_info[i].count; + } + CUDF_CUDA_TRY(cudaMemcpyAsync(counts.data(), h_counts.data(), num_rows * sizeof(int32_t), + cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); + + int32_t last_offset_h = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_offset_h, list_offs.data() + num_rows - 1, sizeof(int32_t), + cudaMemcpyDeviceToHost, stream.value())); + int32_t last_count_h = h_counts[num_rows - 1]; + stream.synchronize(); + last_offset_h += last_count_h; + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &last_offset_h, sizeof(int32_t), + cudaMemcpyHostToDevice, stream.value())); + + // Extract values + rmm::device_uvector values(total_count, stream, mr); + rmm::device_uvector 
d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + + auto const threads = 256; + auto const blocks = (total_count + threads - 1) / threads; - // ========================================================================= - // Pass 1: Scan all messages and record field locations - // ========================================================================= - scan_all_fields_kernel<<>>( - *d_in, d_field_descs.data(), num_decoded_fields, d_locations.data(), d_error.data()); + int encoding = field_desc.encoding; + bool zigzag = (encoding == spark_rapids_jni::ENC_ZIGZAG); + + // For float/double types, always use fixed kernel (they use wire type 32BIT/64BIT) + // For integer types, use fixed kernel only if encoding is ENC_FIXED + constexpr bool is_floating_point = std::is_same_v || std::is_same_v; + bool use_fixed_kernel = is_floating_point || (encoding == spark_rapids_jni::ENC_FIXED); - // ========================================================================= - // Check required fields (after scan pass) - // ========================================================================= - // Only check if any field is required to avoid unnecessary kernel launch - bool has_required_fields = - std::any_of(is_required.begin(), is_required.end(), [](bool b) { return b; }); - if (has_required_fields) { - // Copy is_required flags to device - // Note: std::vector is special (bitfield), so we convert to uint8_t - rmm::device_uvector d_is_required(num_decoded_fields, stream, mr); - std::vector h_is_required_vec(num_decoded_fields); - for (int i = 0; i < num_decoded_fields; i++) { - h_is_required_vec[i] = is_required[i] ? 
1 : 0; + if (use_fixed_kernel) { + if constexpr (sizeof(T) == 4) { + extract_repeated_fixed_kernel<<>>( + message_data, list_offsets, base_offset, d_occurrences.data(), total_count, values.data(), d_error.data()); + } else { + extract_repeated_fixed_kernel<<>>( + message_data, list_offsets, base_offset, d_occurrences.data(), total_count, values.data(), d_error.data()); } - CUDF_CUDA_TRY(cudaMemcpyAsync(d_is_required.data(), - h_is_required_vec.data(), - num_decoded_fields * sizeof(uint8_t), - cudaMemcpyHostToDevice, - stream.value())); + } else if (zigzag) { + extract_repeated_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_occurrences.data(), total_count, values.data(), d_error.data()); + } else { + extract_repeated_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_occurrences.data(), total_count, values.data(), d_error.data()); + } - check_required_fields_kernel<<>>( - d_locations.data(), d_is_required.data(), num_decoded_fields, rows, d_error.data()); + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), rmm::device_buffer{}, 0); + auto child_col = std::make_unique( + cudf::data_type{static_cast(field_desc.output_type_id)}, total_count, values.release(), rmm::device_buffer{}, 0); + + // Only rows where INPUT is null should produce null output + // Rows with valid input but count=0 should produce empty array [] + if (input_null_count > 0) { + // Copy input null mask - only input nulls produce output nulls + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), + input_null_count, std::move(null_mask), stream, mr); } - // Get message data pointer and offsets for pass 2 + return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); +} + +/** + * Build a repeated string/bytes column (LIST of STRING or LIST). 
+ */ +std::unique_ptr build_repeated_string_column( + cudf::column_view const& binary_input, + device_nested_field_descriptor const& field_desc, + std::vector const& h_repeated_info, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + bool is_bytes, // true for bytes (LIST), false for string + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + // Get input column's null mask to determine which output rows should be null + // Only rows where INPUT is null should produce null output + // Rows with valid input but count=0 should produce empty array [] + auto const input_null_count = binary_input.null_count(); + + if (total_count == 0) { + // All rows have count=0, but we still need to check input nulls + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), rmm::device_buffer{}, 0); + auto child_col = is_bytes + ? 
make_empty_column_safe(cudf::data_type{cudf::type_id::LIST}, stream, mr) // LIST + : cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); + + if (input_null_count > 0) { + // Copy input null mask - only input nulls produce output nulls + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), + input_null_count, std::move(null_mask), stream, mr); + } else { + // No input nulls, all rows get empty arrays [] + return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), + 0, rmm::device_buffer{}, stream, mr); + } + } + + cudf::lists_column_view const in_list(binary_input); auto const* message_data = reinterpret_cast(in_list.child().data()); auto const* list_offsets = in_list.offsets().data(); - // Get the base offset by copying from device to host + cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync( - &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&base_offset, list_offsets, sizeof(cudf::size_type), + cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - // ========================================================================= - // Pass 2: Extract data for each field - // ========================================================================= - std::vector> all_children(total_num_fields); - int decoded_idx = 0; + // Build list offsets from counts + rmm::device_uvector counts(num_rows, stream, mr); + std::vector h_counts(num_rows); + for (int i = 0; i < num_rows; i++) { + h_counts[i] = h_repeated_info[i].count; + } + CUDF_CUDA_TRY(cudaMemcpyAsync(counts.data(), h_counts.data(), num_rows * sizeof(int32_t), + cudaMemcpyHostToDevice, stream.value())); - for (int schema_idx = 0; schema_idx < total_num_fields; schema_idx++) { - if (decoded_idx < num_decoded_fields && decoded_field_indices[decoded_idx] == schema_idx) { - // This 
field needs to be decoded - auto const dt = all_types[schema_idx]; - auto const enc = encodings[decoded_idx]; + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); - switch (dt.id()) { - case cudf::type_id::BOOL8: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_val = has_def ? (default_bools[decoded_idx] ? 1 : 0) : 0; - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_val); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); - break; - } + int32_t last_offset_h = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_offset_h, list_offs.data() + num_rows - 1, sizeof(int32_t), + cudaMemcpyDeviceToHost, stream.value())); + int32_t last_count_h = h_counts[num_rows - 1]; + stream.synchronize(); + last_offset_h += last_count_h; + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &last_offset_h, sizeof(int32_t), + cudaMemcpyHostToDevice, stream.value())); - case cudf::type_id::INT32: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_int = has_def ? 
default_ints[decoded_idx] : 0; - int32_t def_fixed = static_cast(def_int); - if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); - } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_fixed); - } else { - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); + // Extract string lengths from occurrences + rmm::device_uvector str_lengths(total_count, stream, mr); + auto const threads = 256; + auto const blocks = (total_count + threads - 1) / threads; + extract_repeated_lengths_kernel<<>>( + d_occurrences.data(), total_count, str_lengths.data()); + + // Compute string offsets via prefix sum + rmm::device_uvector str_offsets(total_count + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), str_lengths.begin(), str_lengths.end(), str_offsets.begin(), 0); + + int32_t total_chars = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, str_offsets.data() + total_count - 1, sizeof(int32_t), + cudaMemcpyDeviceToHost, stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, str_lengths.data() + total_count - 1, sizeof(int32_t), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + total_chars += last_len; + CUDF_CUDA_TRY(cudaMemcpyAsync(str_offsets.data() + total_count, &total_chars, sizeof(int32_t), + cudaMemcpyHostToDevice, stream.value())); + + // Copy string data + rmm::device_uvector chars(total_chars, stream, mr); + 
if (total_chars > 0) { + copy_repeated_varlen_data_kernel<<>>( + message_data, list_offsets, base_offset, d_occurrences.data(), total_count, + str_offsets.data(), chars.data()); + } + + // Build the child column (either STRING or LIST) + std::unique_ptr child_col; + if (is_bytes) { + // Build LIST for bytes (Spark BinaryType maps to LIST) + auto str_offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, total_count + 1, str_offsets.release(), rmm::device_buffer{}, 0); + auto bytes_child = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, total_chars, + rmm::device_buffer(chars.data(), total_chars, stream, mr), rmm::device_buffer{}, 0); + child_col = cudf::make_lists_column(total_count, std::move(str_offsets_col), std::move(bytes_child), + 0, rmm::device_buffer{}, stream, mr); + } else { + // Build STRING column + auto str_offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, total_count + 1, str_offsets.release(), rmm::device_buffer{}, 0); + child_col = cudf::make_strings_column(total_count, std::move(str_offsets_col), chars.release(), 0, rmm::device_buffer{}); + } + + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), rmm::device_buffer{}, 0); + + // Only rows where INPUT is null should produce null output + // Rows with valid input but count=0 should produce empty array [] + if (input_null_count > 0) { + // Copy input null mask - only input nulls produce output nulls + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), + input_null_count, std::move(null_mask), stream, mr); + } + + return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); +} + +/** + * Build a repeated struct column (LIST of STRUCT). 
+ * This handles repeated message fields like: repeated Item items = 2; + * The output is ArrayType(StructType(...)) + */ +std::unique_ptr build_repeated_struct_column( + cudf::column_view const& binary_input, + device_nested_field_descriptor const& field_desc, + std::vector const& h_repeated_info, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + // Child field information + std::vector const& h_device_schema, + std::vector const& child_field_indices, // Indices of child fields in schema + std::vector const& schema_output_types, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const input_null_count = binary_input.null_count(); + int num_child_fields = static_cast(child_field_indices.size()); + + if (total_count == 0 || num_child_fields == 0) { + // All rows have count=0 or no child fields - return list of empty structs + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), rmm::device_buffer{}, 0); + + // Build empty struct child column with proper nested structure + int num_schema_fields = static_cast(h_device_schema.size()); + std::vector> empty_struct_children; + for (int child_schema_idx : child_field_indices) { + auto child_type = schema_output_types[child_schema_idx]; + if (child_type.id() == cudf::type_id::STRUCT) { + // Use helper to recursively build nested struct + empty_struct_children.push_back(make_empty_struct_column_with_schema( + h_device_schema, schema_output_types, child_schema_idx, num_schema_fields, stream, mr)); + } else { + empty_struct_children.push_back(make_empty_column_safe(child_type, stream, mr)); + } + } + auto empty_struct = cudf::make_structs_column(0, 
std::move(empty_struct_children), 0, rmm::device_buffer{}, stream, mr); + + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(empty_struct), + input_null_count, std::move(null_mask), stream, mr); + } else { + return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(empty_struct), + 0, rmm::device_buffer{}, stream, mr); + } + } + + cudf::lists_column_view const in_list(binary_input); + auto const* message_data = reinterpret_cast(in_list.child().data()); + auto const* list_offsets = in_list.offsets().data(); + + cudf::size_type base_offset = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&base_offset, list_offsets, sizeof(cudf::size_type), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + // Build list offsets from counts (for the outer LIST column) + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + std::vector h_counts(num_rows); + for (int i = 0; i < num_rows; i++) { + h_counts[i] = h_repeated_info[i].count; + } + rmm::device_uvector counts(num_rows, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(counts.data(), h_counts.data(), num_rows * sizeof(int32_t), + cudaMemcpyHostToDevice, stream.value())); + thrust::exclusive_scan(rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); + + int32_t last_offset_h = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_offset_h, list_offs.data() + num_rows - 1, sizeof(int32_t), + cudaMemcpyDeviceToHost, stream.value())); + int32_t last_count_h = h_counts[num_rows - 1]; + stream.synchronize(); + last_offset_h += last_count_h; + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &last_offset_h, sizeof(int32_t), + cudaMemcpyHostToDevice, stream.value())); + + // Copy occurrences to host for processing + std::vector h_occurrences(total_count); + CUDF_CUDA_TRY(cudaMemcpyAsync(h_occurrences.data(), d_occurrences.data(), + total_count * 
sizeof(repeated_occurrence), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + // Build child field descriptors for scanning within each message occurrence + std::vector h_child_descs(num_child_fields); + for (int ci = 0; ci < num_child_fields; ci++) { + int child_schema_idx = child_field_indices[ci]; + h_child_descs[ci].field_number = h_device_schema[child_schema_idx].field_number; + h_child_descs[ci].expected_wire_type = h_device_schema[child_schema_idx].wire_type; + } + rmm::device_uvector d_child_descs(num_child_fields, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_descs.data(), h_child_descs.data(), + num_child_fields * sizeof(field_descriptor), + cudaMemcpyHostToDevice, stream.value())); + + // For each occurrence, we need to scan for child fields + // Create "virtual" parent locations from the occurrences + // Each occurrence becomes a "parent" message for child field scanning + std::vector h_msg_locs(total_count); + std::vector h_msg_row_offsets(total_count); + for (int i = 0; i < total_count; i++) { + auto const& occ = h_occurrences[i]; + // Get the row's start offset in the binary column + cudf::size_type row_offset; + CUDF_CUDA_TRY(cudaMemcpyAsync(&row_offset, list_offsets + occ.row_idx, sizeof(cudf::size_type), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + h_msg_row_offsets[i] = static_cast(row_offset - base_offset); + h_msg_locs[i] = {occ.offset, occ.length}; + } + + rmm::device_uvector d_msg_locs(total_count, stream, mr); + rmm::device_uvector d_msg_row_offsets(total_count, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_msg_locs.data(), h_msg_locs.data(), + total_count * sizeof(field_location), + cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_msg_row_offsets.data(), h_msg_row_offsets.data(), + total_count * sizeof(int32_t), + cudaMemcpyHostToDevice, stream.value())); + + // Scan for child fields within each message occurrence + rmm::device_uvector 
d_child_locs(total_count * num_child_fields, stream, mr); + rmm::device_uvector d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + + auto const threads = 256; + auto const blocks = (total_count + threads - 1) / threads; + + // Use a custom kernel to scan child fields within message occurrences + // This is similar to scan_nested_message_fields_kernel but operates on occurrences + scan_repeated_message_children_kernel<<>>( + message_data, d_msg_row_offsets.data(), d_msg_locs.data(), total_count, + d_child_descs.data(), num_child_fields, d_child_locs.data(), d_error.data()); + + // Copy child locations to host + std::vector h_child_locs(total_count * num_child_fields); + CUDF_CUDA_TRY(cudaMemcpyAsync(h_child_locs.data(), d_child_locs.data(), + h_child_locs.size() * sizeof(field_location), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + // Extract child field values - build one column per child field + std::vector> struct_children; + for (int ci = 0; ci < num_child_fields; ci++) { + int child_schema_idx = child_field_indices[ci]; + auto const dt = schema_output_types[child_schema_idx]; + auto const enc = h_device_schema[child_schema_idx].encoding; + bool has_def = h_device_schema[child_schema_idx].has_default_value; + + switch (dt.id()) { + case cudf::type_id::BOOL8: { + rmm::device_uvector out(total_count, stream, mr); + rmm::device_uvector valid(total_count, stream, mr); + int64_t def_val = has_def ? (default_bools[child_schema_idx] ? 
1 : 0) : 0; + extract_repeated_msg_child_varint_kernel<<>>( + message_data, d_msg_row_offsets.data(), d_msg_locs.data(), + d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), + total_count, d_error.data(), has_def, def_val); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, total_count, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::INT32: { + rmm::device_uvector out(total_count, stream, mr); + rmm::device_uvector valid(total_count, stream, mr); + int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; + if (enc == spark_rapids_jni::ENC_ZIGZAG) { + extract_repeated_msg_child_varint_kernel<<>>( + message_data, d_msg_row_offsets.data(), d_msg_locs.data(), + d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), + total_count, d_error.data(), has_def, def_int); + } else if (enc == spark_rapids_jni::ENC_FIXED) { + extract_repeated_msg_child_fixed_kernel<<>>( + message_data, d_msg_row_offsets.data(), d_msg_locs.data(), + d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), + total_count, d_error.data(), has_def, static_cast(def_int)); + } else { + extract_repeated_msg_child_varint_kernel<<>>( + message_data, d_msg_row_offsets.data(), d_msg_locs.data(), + d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), + total_count, d_error.data(), has_def, def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, total_count, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::INT64: { + rmm::device_uvector out(total_count, stream, mr); + rmm::device_uvector valid(total_count, stream, mr); + int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; + if (enc == spark_rapids_jni::ENC_ZIGZAG) { + extract_repeated_msg_child_varint_kernel<<>>( + message_data, d_msg_row_offsets.data(), d_msg_locs.data(), + d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), + total_count, d_error.data(), has_def, def_int); + } else if (enc == spark_rapids_jni::ENC_FIXED) { + extract_repeated_msg_child_fixed_kernel<<>>( + message_data, d_msg_row_offsets.data(), d_msg_locs.data(), + d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), + total_count, d_error.data(), has_def, def_int); + } else { + extract_repeated_msg_child_varint_kernel<<>>( + message_data, d_msg_row_offsets.data(), d_msg_locs.data(), + d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), + total_count, d_error.data(), has_def, def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, total_count, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::FLOAT32: { + rmm::device_uvector out(total_count, stream, mr); + rmm::device_uvector valid(total_count, stream, mr); + float def_float = has_def ? static_cast(default_floats[child_schema_idx]) : 0.0f; + extract_repeated_msg_child_fixed_kernel<<>>( + message_data, d_msg_row_offsets.data(), d_msg_locs.data(), + d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), + total_count, d_error.data(), has_def, def_float); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, total_count, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::FLOAT64: { + rmm::device_uvector out(total_count, stream, mr); + rmm::device_uvector valid(total_count, stream, mr); + double def_double = has_def ? 
default_floats[child_schema_idx] : 0.0; + extract_repeated_msg_child_fixed_kernel<<>>( + message_data, d_msg_row_offsets.data(), d_msg_locs.data(), + d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), + total_count, d_error.data(), has_def, def_double); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, total_count, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::STRING: { + // For strings, we need a two-pass approach: first get lengths, then copy data + struct_children.push_back( + build_repeated_msg_child_string_column( + message_data, d_msg_row_offsets, d_msg_locs, + d_child_locs, ci, num_child_fields, total_count, d_error, stream, mr)); + break; + } + case cudf::type_id::STRUCT: { + // Nested struct inside repeated message - need to extract grandchild fields + int num_schema_fields = static_cast(h_device_schema.size()); + auto grandchild_indices = find_child_field_indices(h_device_schema, num_schema_fields, child_schema_idx); + + if (grandchild_indices.empty()) { + // No grandchildren - create empty struct column + struct_children.push_back(cudf::make_structs_column( + total_count, std::vector>{}, 0, rmm::device_buffer{}, stream, mr)); + } else { + // Build grandchild columns + // For each occurrence, the nested struct location is in child_locs[occ * num_child_fields + ci] + // We need to scan within each nested struct for grandchild fields + + // Build grandchild field descriptors + int num_grandchildren = static_cast(grandchild_indices.size()); + std::vector h_gc_descs(num_grandchildren); + for (int gci = 0; gci < num_grandchildren; gci++) { + int gc_schema_idx = grandchild_indices[gci]; + h_gc_descs[gci].field_number = h_device_schema[gc_schema_idx].field_number; + h_gc_descs[gci].expected_wire_type = h_device_schema[gc_schema_idx].wire_type; } + rmm::device_uvector d_gc_descs(num_grandchildren, stream, mr); + 
CUDF_CUDA_TRY(cudaMemcpyAsync(d_gc_descs.data(), h_gc_descs.data(), + num_grandchildren * sizeof(field_descriptor), + cudaMemcpyHostToDevice, stream.value())); + + // Create nested struct locations from child_locs + // Each occurrence's nested struct is at child_locs[occ * num_child_fields + ci] + std::vector h_nested_locs(total_count); + std::vector h_nested_row_offsets(total_count); + for (int occ = 0; occ < total_count; occ++) { + auto const& nested_loc = h_child_locs[occ * num_child_fields + ci]; + auto const& msg_loc = h_msg_locs[occ]; + h_nested_row_offsets[occ] = h_msg_row_offsets[occ] + msg_loc.offset; + h_nested_locs[occ] = nested_loc; + } + + rmm::device_uvector d_nested_locs(total_count, stream, mr); + rmm::device_uvector d_nested_row_offsets(total_count, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_nested_locs.data(), h_nested_locs.data(), + total_count * sizeof(field_location), + cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_nested_row_offsets.data(), h_nested_row_offsets.data(), + total_count * sizeof(int32_t), + cudaMemcpyHostToDevice, stream.value())); + + // Scan for grandchild fields + rmm::device_uvector d_gc_locs(total_count * num_grandchildren, stream, mr); + scan_repeated_message_children_kernel<<>>( + message_data, d_nested_row_offsets.data(), d_nested_locs.data(), total_count, + d_gc_descs.data(), num_grandchildren, d_gc_locs.data(), d_error.data()); + + // Copy grandchild locations to host + std::vector h_gc_locs(total_count * num_grandchildren); + CUDF_CUDA_TRY(cudaMemcpyAsync(h_gc_locs.data(), d_gc_locs.data(), + h_gc_locs.size() * sizeof(field_location), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + // Extract grandchild values + std::vector> grandchild_cols; + for (int gci = 0; gci < num_grandchildren; gci++) { + int gc_schema_idx = grandchild_indices[gci]; + auto const gc_dt = schema_output_types[gc_schema_idx]; + auto const gc_enc = h_device_schema[gc_schema_idx].encoding; + 
bool gc_has_def = h_device_schema[gc_schema_idx].has_default_value; + + switch (gc_dt.id()) { + case cudf::type_id::INT32: { + rmm::device_uvector out(total_count, stream, mr); + rmm::device_uvector valid(total_count, stream, mr); + int64_t def_val = gc_has_def ? default_ints[gc_schema_idx] : 0; + if (gc_enc == spark_rapids_jni::ENC_ZIGZAG) { + extract_repeated_msg_child_varint_kernel<<>>( + message_data, d_nested_row_offsets.data(), d_nested_locs.data(), + d_gc_locs.data(), gci, num_grandchildren, out.data(), valid.data(), + total_count, d_error.data(), gc_has_def, def_val); + } else { + extract_repeated_msg_child_varint_kernel<<>>( + message_data, d_nested_row_offsets.data(), d_nested_locs.data(), + d_gc_locs.data(), gci, num_grandchildren, out.data(), valid.data(), + total_count, d_error.data(), gc_has_def, def_val); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + grandchild_cols.push_back(std::make_unique( + gc_dt, total_count, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::INT64: { + rmm::device_uvector out(total_count, stream, mr); + rmm::device_uvector valid(total_count, stream, mr); + int64_t def_val = gc_has_def ? 
default_ints[gc_schema_idx] : 0; + extract_repeated_msg_child_varint_kernel<<>>( + message_data, d_nested_row_offsets.data(), d_nested_locs.data(), + d_gc_locs.data(), gci, num_grandchildren, out.data(), valid.data(), + total_count, d_error.data(), gc_has_def, def_val); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + grandchild_cols.push_back(std::make_unique( + gc_dt, total_count, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::STRING: { + grandchild_cols.push_back( + build_repeated_msg_child_string_column( + message_data, d_nested_row_offsets, d_nested_locs, + d_gc_locs, gci, num_grandchildren, total_count, d_error, stream, mr)); + break; + } + default: + // Unsupported grandchild type - create null column + grandchild_cols.push_back(make_null_column(gc_dt, total_count, stream, mr)); + break; + } + } + + // Build the nested struct column + auto nested_struct_col = cudf::make_structs_column( + total_count, std::move(grandchild_cols), 0, rmm::device_buffer{}, stream, mr); + struct_children.push_back(std::move(nested_struct_col)); + } + break; + } + default: + // Unsupported child type - create null column + struct_children.push_back(make_null_column(dt, total_count, stream, mr)); + break; + } + } - // Validate enum values if this is an enum field - // enum_valid_values[decoded_idx] is non-empty for enum fields - auto const& valid_enums = enum_valid_values[decoded_idx]; - if (!valid_enums.empty()) { - // Copy valid enum values to device - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + // Build the struct column from child columns + auto struct_col = cudf::make_structs_column(total_count, std::move(struct_children), 0, rmm::device_buffer{}, stream, mr); - // Validate enum values - unknown values will null the entire row - 
validate_enum_values_kernel<<>>( - out.data(), - valid.data(), - d_row_has_invalid_enum.data(), - d_valid_enums.data(), - static_cast(valid_enums.size()), - rows); - } + // Build the list offsets column + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), rmm::device_buffer{}, 0); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); - break; + // Build the final LIST column + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(struct_col), + input_null_count, std::move(null_mask), stream, mr); + } + + return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(struct_col), 0, rmm::device_buffer{}, stream, mr); +} + +} // anonymous namespace + +std::unique_ptr decode_nested_protobuf_to_struct( + cudf::column_view const& binary_input, + std::vector const& schema, + std::vector const& schema_output_types, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + bool fail_on_errors) +{ + CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, + "binary_input must be a LIST column"); + cudf::lists_column_view const in_list(binary_input); + auto const child_type = in_list.child().type().id(); + CUDF_EXPECTS(child_type == cudf::type_id::INT8 || child_type == cudf::type_id::UINT8, + "binary_input must be a LIST column"); + + auto const stream = cudf::get_default_stream(); + auto mr = cudf::get_current_device_resource_ref(); + auto num_rows = binary_input.size(); + auto num_fields = static_cast(schema.size()); + + if (num_rows == 0 || num_fields == 0) { + // Build empty struct based on top-level fields with proper nested 
structure + std::vector> empty_children; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == -1) { + auto field_type = schema_output_types[i]; + if (schema[i].is_repeated && field_type.id() == cudf::type_id::STRUCT) { + // Repeated message field - build empty LIST with proper struct element + rmm::device_uvector offsets(1, stream, mr); + int32_t zero = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(offsets.data(), &zero, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, 1, offsets.release(), rmm::device_buffer{}, 0); + auto empty_struct = make_empty_struct_column_with_schema( + schema, schema_output_types, i, num_fields, stream, mr); + empty_children.push_back(cudf::make_lists_column(0, std::move(offsets_col), std::move(empty_struct), + 0, rmm::device_buffer{}, stream, mr)); + } else if (field_type.id() == cudf::type_id::STRUCT && !schema[i].is_repeated) { + // Non-repeated nested message field + empty_children.push_back(make_empty_struct_column_with_schema( + schema, schema_output_types, i, num_fields, stream, mr)); + } else { + empty_children.push_back(make_empty_column_safe(field_type, stream, mr)); } + } + } + return cudf::make_structs_column(0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); + } - case cudf::type_id::UINT32: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_int = has_def ? 
default_ints[decoded_idx] : 0; - uint32_t def_fixed = static_cast(def_int); - if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_fixed); - } else { - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); - } + // Copy schema to device + std::vector h_device_schema(num_fields); + for (int i = 0; i < num_fields; i++) { + h_device_schema[i] = { + schema[i].field_number, + schema[i].parent_idx, + schema[i].depth, + schema[i].wire_type, + static_cast(schema[i].output_type), + schema[i].encoding, + schema[i].is_repeated, + schema[i].is_required, + schema[i].has_default_value + }; + } + + rmm::device_uvector d_schema(num_fields, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_schema.data(), h_device_schema.data(), + num_fields * sizeof(device_nested_field_descriptor), + cudaMemcpyHostToDevice, stream.value())); + + auto d_in = cudf::column_device_view::create(binary_input, stream); + + // Identify repeated and nested fields at depth 0 + std::vector repeated_field_indices; + std::vector nested_field_indices; + std::vector scalar_field_indices; + + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == -1) { // Top-level fields only + if (schema[i].is_repeated) { + repeated_field_indices.push_back(i); + } else if (schema[i].output_type == cudf::type_id::STRUCT) { + nested_field_indices.push_back(i); + } else { + scalar_field_indices.push_back(i); + } + } + } + + int num_repeated = static_cast(repeated_field_indices.size()); + int num_nested = static_cast(nested_field_indices.size()); + int num_scalar = static_cast(scalar_field_indices.size()); + + // Error flag + rmm::device_uvector 
d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + + auto const threads = 256; + auto const blocks = static_cast((num_rows + threads - 1) / threads); + + // Allocate for counting repeated fields + rmm::device_uvector d_repeated_info( + num_repeated > 0 ? static_cast(num_rows) * num_repeated : 1, stream, mr); + rmm::device_uvector d_nested_locations( + num_nested > 0 ? static_cast(num_rows) * num_nested : 1, stream, mr); + + rmm::device_uvector d_repeated_indices(num_repeated > 0 ? num_repeated : 1, stream, mr); + rmm::device_uvector d_nested_indices(num_nested > 0 ? num_nested : 1, stream, mr); + + if (num_repeated > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_repeated_indices.data(), repeated_field_indices.data(), + num_repeated * sizeof(int), cudaMemcpyHostToDevice, stream.value())); + } + if (num_nested > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_nested_indices.data(), nested_field_indices.data(), + num_nested * sizeof(int), cudaMemcpyHostToDevice, stream.value())); + } + + // Count repeated fields at depth 0 + if (num_repeated > 0 || num_nested > 0) { + count_repeated_fields_kernel<<>>( + *d_in, + d_schema.data(), + num_fields, + 0, // depth_level + d_repeated_info.data(), + num_repeated, + d_repeated_indices.data(), + d_nested_locations.data(), + num_nested, + d_nested_indices.data(), + d_error.data()); + } + + // For scalar fields at depth 0, use the existing scan_all_fields_kernel + // Use a map to store columns by schema index, then assemble in order at the end + std::map> column_map; + + // Process scalar fields using existing infrastructure + if (num_scalar > 0) { + std::vector h_field_descs(num_scalar); + for (int i = 0; i < num_scalar; i++) { + int schema_idx = scalar_field_indices[i]; + h_field_descs[i].field_number = schema[schema_idx].field_number; + h_field_descs[i].expected_wire_type = schema[schema_idx].wire_type; + } + + rmm::device_uvector d_field_descs(num_scalar, stream, mr); + 
CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_descs.data(), h_field_descs.data(), + num_scalar * sizeof(field_descriptor), + cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector d_locations( + static_cast(num_rows) * num_scalar, stream, mr); + + scan_all_fields_kernel<<>>( + *d_in, d_field_descs.data(), num_scalar, d_locations.data(), d_error.data()); + + // Extract scalar values (reusing existing extraction logic) + cudf::lists_column_view const in_list_view(binary_input); + auto const* message_data = reinterpret_cast(in_list_view.child().data()); + auto const* list_offsets = in_list_view.offsets().data(); + + cudf::size_type base_offset = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&base_offset, list_offsets, sizeof(cudf::size_type), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + for (int i = 0; i < num_scalar; i++) { + int schema_idx = scalar_field_indices[i]; + auto const dt = schema_output_types[schema_idx]; + auto const enc = schema[schema_idx].encoding; + bool has_def = schema[schema_idx].has_default_value; + + switch (dt.id()) { + case cudf::type_id::BOOL8: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_val = has_def ? (default_bools[schema_idx] ? 1 : 0) : 0; + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, def_val); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + column_map[schema_idx] = std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count); break; } - - case cudf::type_id::INT64: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_int = has_def ? 
default_ints[decoded_idx] : 0; + case cudf::type_id::INT32: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_int = has_def ? default_ints[schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); + extract_fixed_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, static_cast(def_int)); } else { - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + column_map[schema_idx] = std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count); break; } - - case cudf::type_id::UINT64: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector 
valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_int = has_def ? default_ints[decoded_idx] : 0; - uint64_t def_fixed = static_cast(def_int); - if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_fixed); + case cudf::type_id::INT64: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_int = has_def ? default_ints[schema_idx] : 0; + if (enc == spark_rapids_jni::ENC_ZIGZAG) { + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); + } else if (enc == spark_rapids_jni::ENC_FIXED) { + extract_fixed_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); } else { - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + column_map[schema_idx] = std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count); break; } - case cudf::type_id::FLOAT32: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = 
has_default_value[decoded_idx]; - float def_float = has_def ? static_cast(default_floats[decoded_idx]) : 0.0f; - extract_fixed_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_float); + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + float def_float = has_def ? static_cast(default_floats[schema_idx]) : 0.0f; + extract_fixed_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, def_float); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + column_map[schema_idx] = std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count); break; } - case cudf::type_id::FLOAT64: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - double def_double = has_def ? default_floats[decoded_idx] : 0.0; - extract_fixed_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_double); + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + double def_double = has_def ? 
default_floats[schema_idx] : 0.0; + extract_fixed_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, def_double); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); + column_map[schema_idx] = std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count); break; } - case cudf::type_id::STRING: { - // Check for default value - bool has_def = has_default_value[decoded_idx]; - auto const& def_str = default_strings[decoded_idx]; - int32_t def_len = has_def ? static_cast(def_str.size()) : 0; + // Extract top-level STRING scalar field + bool has_def_str = has_def && !default_strings[schema_idx].empty(); + auto const& def_str = default_strings[schema_idx]; + int32_t def_len = has_def_str ? static_cast(def_str.size()) : 0; - // Copy default string to device if needed rmm::device_uvector d_default_str(def_len, stream, mr); - if (has_def && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), - def_str.data(), - def_len, - cudaMemcpyHostToDevice, - stream.value())); + if (has_def_str && def_len > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), def_str.data(), def_len, + cudaMemcpyHostToDevice, stream.value())); } - // Extract lengths and compute output offsets via prefix sum - rmm::device_uvector lengths(rows, stream, mr); - extract_lengths_kernel<<>>(d_locations.data(), - decoded_idx, - num_decoded_fields, - lengths.data(), - rows, - has_def, - def_len); + // Extract string lengths + rmm::device_uvector lengths(num_rows, stream, mr); + extract_scalar_string_lengths_kernel<<>>( + d_locations.data(), i, num_scalar, lengths.data(), num_rows, has_def_str, def_len); - rmm::device_uvector output_offsets(rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), lengths.begin(), 
lengths.end(), output_offsets.begin(), 0); + // Compute offsets via prefix sum + rmm::device_uvector output_offsets(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), + output_offsets.begin(), 0); - // Get total size int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, - output_offsets.data() + rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, output_offsets.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - lengths.data() + rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); total_chars += last_len; + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_chars, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - // Set the final offset - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + rows, - &total_chars, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - // Allocate and copy character data + // Copy string data rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { - copy_varlen_data_kernel<<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - output_offsets.data(), - chars.data(), - rows, - has_def, - d_default_str.data(), - def_len); + copy_scalar_string_data_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + output_offsets.data(), chars.data(), num_rows, has_def_str, + d_default_str.data(), def_len); } - // Create validity mask (field found OR has default = valid) - rmm::device_uvector valid(rows, stream, mr); - thrust::transform( - rmm::exec_policy(stream), - 
thrust::make_counting_iterator(0), - thrust::make_counting_iterator(rows), - valid.begin(), - [locs = d_locations.data(), decoded_idx, num_decoded_fields, has_def] __device__( - auto row) { - return locs[row * num_decoded_fields + decoded_idx].offset >= 0 || has_def; - }); + // Build validity mask + rmm::device_uvector valid(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.begin(), + [locs = d_locations.data(), i, num_scalar, has_def_str] __device__(auto row) { + return locs[row * num_scalar + i].offset >= 0 || has_def_str; + }); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - // Create offsets column auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - rows + 1, - output_offsets.release(), - rmm::device_buffer{}, - 0); - - // Create strings column using offsets + chars buffer - all_children[schema_idx] = cudf::make_strings_column( - rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); + num_rows + 1, output_offsets.release(), + rmm::device_buffer{}, 0); + column_map[schema_idx] = cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); break; } + default: + // For LIST (bytes) and other unsupported types, create placeholder columns + column_map[schema_idx] = make_null_column(dt, num_rows, stream, mr); + break; + } + } + } - case cudf::type_id::LIST: { - // For protobuf bytes: create LIST directly (optimization #2) - // Check for default value - bool has_def = has_default_value[decoded_idx]; - auto const& def_bytes = default_strings[decoded_idx]; - int32_t def_len = has_def ? 
static_cast(def_bytes.size()) : 0; + // Process repeated fields + if (num_repeated > 0) { + std::vector h_repeated_info(static_cast(num_rows) * num_repeated); + CUDF_CUDA_TRY(cudaMemcpyAsync(h_repeated_info.data(), d_repeated_info.data(), + h_repeated_info.size() * sizeof(repeated_field_info), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); - // Copy default bytes to device if needed - rmm::device_uvector d_default_bytes(def_len, stream, mr); - if (has_def && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_bytes.data(), - def_bytes.data(), - def_len, - cudaMemcpyHostToDevice, - stream.value())); + cudf::lists_column_view const in_list_view(binary_input); + auto const* list_offsets = in_list_view.offsets().data(); + + for (int ri = 0; ri < num_repeated; ri++) { + int schema_idx = repeated_field_indices[ri]; + auto element_type = schema_output_types[schema_idx]; + + // Get per-row info for this repeated field + std::vector field_info(num_rows); + int total_count = 0; + for (int row = 0; row < num_rows; row++) { + field_info[row] = h_repeated_info[row * num_repeated + ri]; + total_count += field_info[row].count; + } + + if (total_count > 0) { + // Build offsets for occurrence scanning + rmm::device_uvector d_occ_offsets(num_rows + 1, stream, mr); + std::vector h_occ_offsets(num_rows + 1); + h_occ_offsets[0] = 0; + for (int row = 0; row < num_rows; row++) { + h_occ_offsets[row + 1] = h_occ_offsets[row] + field_info[row].count; + } + CUDF_CUDA_TRY(cudaMemcpyAsync(d_occ_offsets.data(), h_occ_offsets.data(), + (num_rows + 1) * sizeof(int32_t), + cudaMemcpyHostToDevice, stream.value())); + + // Scan for all occurrences + rmm::device_uvector d_occurrences(total_count, stream, mr); + scan_repeated_field_occurrences_kernel<<>>( + *d_in, d_schema.data(), schema_idx, 0, d_occ_offsets.data(), + d_occurrences.data(), d_error.data()); + + // Build the appropriate column type based on element type + // For now, support scalar repeated fields + auto 
child_type_id = static_cast(h_device_schema[schema_idx].output_type_id); + + // The output_type in schema is the LIST type, but we need element type + // For repeated int32, output_type should indicate the element is INT32 + switch (child_type_id) { + case cudf::type_id::INT32: + column_map[schema_idx] = build_repeated_scalar_column( + binary_input, h_device_schema[schema_idx], field_info, d_occurrences, + total_count, num_rows, stream, mr); + break; + case cudf::type_id::INT64: + column_map[schema_idx] = build_repeated_scalar_column( + binary_input, h_device_schema[schema_idx], field_info, d_occurrences, + total_count, num_rows, stream, mr); + break; + case cudf::type_id::FLOAT32: + column_map[schema_idx] = build_repeated_scalar_column( + binary_input, h_device_schema[schema_idx], field_info, d_occurrences, + total_count, num_rows, stream, mr); + break; + case cudf::type_id::FLOAT64: + column_map[schema_idx] = build_repeated_scalar_column( + binary_input, h_device_schema[schema_idx], field_info, d_occurrences, + total_count, num_rows, stream, mr); + break; + case cudf::type_id::BOOL8: + column_map[schema_idx] = build_repeated_scalar_column( + binary_input, h_device_schema[schema_idx], field_info, d_occurrences, + total_count, num_rows, stream, mr); + break; + case cudf::type_id::STRING: + column_map[schema_idx] = build_repeated_string_column( + binary_input, h_device_schema[schema_idx], field_info, d_occurrences, + total_count, num_rows, false, stream, mr); + break; + case cudf::type_id::LIST: // bytes as LIST + column_map[schema_idx] = build_repeated_string_column( + binary_input, h_device_schema[schema_idx], field_info, d_occurrences, + total_count, num_rows, true, stream, mr); + break; + case cudf::type_id::STRUCT: { + // Repeated message field - ArrayType(StructType) + auto child_field_indices = find_child_field_indices(schema, num_fields, schema_idx); + if (child_field_indices.empty()) { + // No child fields - create null column + column_map[schema_idx] = 
make_null_column(element_type, num_rows, stream, mr); + } else { + column_map[schema_idx] = build_repeated_struct_column( + binary_input, h_device_schema[schema_idx], field_info, d_occurrences, + total_count, num_rows, h_device_schema, child_field_indices, + schema_output_types, default_ints, default_floats, default_bools, + stream, mr); + } + break; } + default: + // Unsupported element type - create null column + column_map[schema_idx] = make_null_column(element_type, num_rows, stream, mr); + break; + } + } else { + // All rows have count=0 - create list of empty elements + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), rmm::device_buffer{}, 0); + + // Build appropriate empty child column + std::unique_ptr child_col; + auto child_type_id = static_cast(h_device_schema[schema_idx].output_type_id); + if (child_type_id == cudf::type_id::STRUCT) { + // Use helper to build empty struct with proper nested structure + child_col = make_empty_struct_column_with_schema( + schema, schema_output_types, schema_idx, num_fields, stream, mr); + } else { + child_col = make_empty_column_safe(schema_output_types[schema_idx], stream, mr); + } + + auto const input_null_count = binary_input.null_count(); + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + column_map[schema_idx] = cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), input_null_count, std::move(null_mask), stream, mr); + } else { + column_map[schema_idx] = cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + } + } + } + } - // Extract lengths and compute output offsets via prefix sum - rmm::device_uvector lengths(rows, stream, mr); - extract_lengths_kernel<<>>(d_locations.data(), - 
decoded_idx, - num_decoded_fields, - lengths.data(), - rows, - has_def, - def_len); + // Process nested struct fields (Phase 2) + if (num_nested > 0) { + // Copy nested locations to host for processing + std::vector h_nested_locations(static_cast(num_rows) * num_nested); + CUDF_CUDA_TRY(cudaMemcpyAsync(h_nested_locations.data(), d_nested_locations.data(), + h_nested_locations.size() * sizeof(field_location), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); - rmm::device_uvector output_offsets(rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); + cudf::lists_column_view const in_list_view(binary_input); + auto const* message_data = reinterpret_cast(in_list_view.child().data()); + auto const* list_offsets = in_list_view.offsets().data(); - // Get total size - int32_t total_bytes = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, - output_offsets.data() + rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - lengths.data() + rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); + cudf::size_type base_offset = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&base_offset, list_offsets, sizeof(cudf::size_type), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + for (int ni = 0; ni < num_nested; ni++) { + int parent_schema_idx = nested_field_indices[ni]; + + // Find child fields of this nested message + auto child_field_indices = find_child_field_indices(schema, num_fields, parent_schema_idx); + + if (child_field_indices.empty()) { + // No child fields - create empty struct + column_map[parent_schema_idx] = make_null_column( + schema_output_types[parent_schema_idx], num_rows, stream, mr); + continue; + } + + int num_child_fields = static_cast(child_field_indices.size()); + + // Build field descriptors for child fields + std::vector 
h_child_field_descs(num_child_fields); + for (int i = 0; i < num_child_fields; i++) { + int child_idx = child_field_indices[i]; + h_child_field_descs[i].field_number = schema[child_idx].field_number; + h_child_field_descs[i].expected_wire_type = schema[child_idx].wire_type; + } + + rmm::device_uvector d_child_field_descs(num_child_fields, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_field_descs.data(), h_child_field_descs.data(), + num_child_fields * sizeof(field_descriptor), + cudaMemcpyHostToDevice, stream.value())); + + // Prepare parent locations for this nested field + rmm::device_uvector d_parent_locs(num_rows, stream, mr); + std::vector h_parent_locs(num_rows); + for (int row = 0; row < num_rows; row++) { + h_parent_locs[row] = h_nested_locations[row * num_nested + ni]; + } + CUDF_CUDA_TRY(cudaMemcpyAsync(d_parent_locs.data(), h_parent_locs.data(), + num_rows * sizeof(field_location), + cudaMemcpyHostToDevice, stream.value())); + + // Scan for child fields within nested messages + rmm::device_uvector d_child_locations( + static_cast(num_rows) * num_child_fields, stream, mr); + + scan_nested_message_fields_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, + d_child_field_descs.data(), num_child_fields, d_child_locations.data(), d_error.data()); + + // Extract child field values + std::vector> struct_children; + for (int ci = 0; ci < num_child_fields; ci++) { + int child_schema_idx = child_field_indices[ci]; + auto const dt = schema_output_types[child_schema_idx]; + auto const enc = schema[child_schema_idx].encoding; + bool has_def = schema[child_schema_idx].has_default_value; + bool is_repeated = schema[child_schema_idx].is_repeated; + + // Check if this is a repeated field (ArrayType) + if (is_repeated) { + // Handle repeated field inside nested message + auto elem_type_id = schema[child_schema_idx].output_type; + + // Copy child locations to host + std::vector h_rep_parent_locs(num_rows); + 
CUDF_CUDA_TRY(cudaMemcpyAsync(h_rep_parent_locs.data(), d_parent_locs.data(), + num_rows * sizeof(field_location), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - total_bytes += last_len; + + // Count repeated field occurrences for each row + rmm::device_uvector d_rep_info(num_rows, stream, mr); + + std::vector rep_indices = {0}; + rmm::device_uvector d_rep_indices(1, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_indices.data(), rep_indices.data(), + sizeof(int), cudaMemcpyHostToDevice, stream.value())); + + device_nested_field_descriptor rep_desc; + rep_desc.field_number = schema[child_schema_idx].field_number; + rep_desc.wire_type = schema[child_schema_idx].wire_type; + rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); + rep_desc.is_repeated = true; + + std::vector h_rep_schema = {rep_desc}; + rmm::device_uvector d_rep_schema(1, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_schema.data(), h_rep_schema.data(), + sizeof(device_nested_field_descriptor), cudaMemcpyHostToDevice, stream.value())); + + count_repeated_in_nested_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, + d_rep_schema.data(), 1, d_rep_info.data(), 1, d_rep_indices.data(), d_error.data()); + + std::vector h_rep_info(num_rows); + CUDF_CUDA_TRY(cudaMemcpyAsync(h_rep_info.data(), d_rep_info.data(), + num_rows * sizeof(repeated_field_info), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + int total_rep_count = 0; + for (int row = 0; row < num_rows; row++) { + total_rep_count += h_rep_info[row].count; + } + + if (total_rep_count == 0) { + rmm::device_uvector list_offsets_vec(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy(stream), list_offsets_vec.begin(), list_offsets_vec.end(), 0); + auto list_offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offsets_vec.release(), rmm::device_buffer{}, 0); + auto child_col = 
make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); + struct_children.push_back(cudf::make_lists_column( + num_rows, std::move(list_offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr)); + } else { + rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); + scan_repeated_in_nested_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, + d_rep_schema.data(), 1, d_rep_info.data(), 1, d_rep_indices.data(), + d_rep_occs.data(), d_error.data()); + + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + std::vector h_list_offs(num_rows + 1); + h_list_offs[0] = 0; + for (int row = 0; row < num_rows; row++) { + h_list_offs[row + 1] = h_list_offs[row] + h_rep_info[row].count; + } + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data(), h_list_offs.data(), + (num_rows + 1) * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + std::unique_ptr child_values; + if (elem_type_id == cudf::type_id::INT32) { + rmm::device_uvector values(total_rep_count, stream, mr); + extract_repeated_in_nested_varint_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); + child_values = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, total_rep_count, values.release(), rmm::device_buffer{}, 0); + } else if (elem_type_id == cudf::type_id::INT64) { + rmm::device_uvector values(total_rep_count, stream, mr); + extract_repeated_in_nested_varint_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); + child_values = std::make_unique( + cudf::data_type{cudf::type_id::INT64}, total_rep_count, values.release(), rmm::device_buffer{}, 0); + } else if (elem_type_id == cudf::type_id::STRING) { + std::vector 
h_rep_occs(total_rep_count); + CUDF_CUDA_TRY(cudaMemcpyAsync(h_rep_occs.data(), d_rep_occs.data(), + total_rep_count * sizeof(repeated_occurrence), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + int32_t total_chars = 0; + std::vector h_str_offs(total_rep_count + 1); + h_str_offs[0] = 0; + for (int i = 0; i < total_rep_count; i++) { + h_str_offs[i + 1] = h_str_offs[i] + h_rep_occs[i].length; + total_chars += h_rep_occs[i].length; + } + + rmm::device_uvector str_offs(total_rep_count + 1, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data(), h_str_offs.data(), + (total_rep_count + 1) * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + extract_repeated_in_nested_string_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_rep_occs.data(), total_rep_count, str_offs.data(), chars.data(), d_error.data()); + } + + auto str_offs_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, total_rep_count + 1, str_offs.release(), rmm::device_buffer{}, 0); + child_values = cudf::make_strings_column(total_rep_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); + } else { + child_values = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); + } + + auto list_offs_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), rmm::device_buffer{}, 0); + struct_children.push_back(cudf::make_lists_column( + num_rows, std::move(list_offs_col), std::move(child_values), 0, rmm::device_buffer{}, stream, mr)); + } + continue; // Skip the switch statement below + } - // Set the final offset - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + rows, - &total_bytes, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + switch (dt.id()) { + case cudf::type_id::BOOL8: { + rmm::device_uvector 
out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_val = has_def ? (default_bools[child_schema_idx] ? 1 : 0) : 0; + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_val); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::INT32: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; + if (enc == spark_rapids_jni::ENC_ZIGZAG) { + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + } else if (enc == spark_rapids_jni::ENC_FIXED) { + extract_nested_fixed_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, static_cast(def_int)); + } else { + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::INT64: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; + if (enc == spark_rapids_jni::ENC_ZIGZAG) { + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + } else if (enc == spark_rapids_jni::ENC_FIXED) { + extract_nested_fixed_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + } else { + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::FLOAT32: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + float def_float = has_def ? static_cast(default_floats[child_schema_idx]) : 0.0f; + extract_nested_fixed_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_float); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::FLOAT64: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + double def_double = has_def ? 
default_floats[child_schema_idx] : 0.0; + extract_nested_fixed_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_double); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::STRING: { + bool has_def_str = has_def && !default_strings[child_schema_idx].empty(); + auto const& def_str = default_strings[child_schema_idx]; + int32_t def_len = has_def_str ? static_cast(def_str.size()) : 0; - // Allocate and copy byte data directly to INT8 buffer - rmm::device_uvector child_data(total_bytes, stream, mr); - if (total_bytes > 0) { - copy_varlen_data_kernel<<>>( - message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - output_offsets.data(), - reinterpret_cast(child_data.data()), - rows, - has_def, - d_default_bytes.data(), - def_len); + rmm::device_uvector d_default_str(def_len, stream, mr); + if (has_def_str && def_len > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), def_str.data(), def_len, + cudaMemcpyHostToDevice, stream.value())); + } + + rmm::device_uvector lengths(num_rows, stream, mr); + extract_nested_lengths_kernel<<>>( + d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields, + lengths.data(), num_rows, has_def_str, def_len); + + rmm::device_uvector output_offsets(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), + output_offsets.begin(), 0); + + int32_t total_chars = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, output_offsets.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, + sizeof(int32_t), 
cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + total_chars += last_len; + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_chars, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + copy_nested_varlen_data_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, output_offsets.data(), + chars.data(), num_rows, has_def_str, d_default_str.data(), def_len); + } + + rmm::device_uvector valid(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.begin(), + [plocs = d_parent_locs.data(), + flocs = d_child_locations.data(), + ci, num_child_fields, has_def_str] __device__(auto row) { + return (plocs[row].offset >= 0 && + flocs[row * num_child_fields + ci].offset >= 0) || has_def_str; + }); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, output_offsets.release(), + rmm::device_buffer{}, 0); + struct_children.push_back(cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask))); + break; } + case cudf::type_id::STRUCT: { + // Recursively process nested struct (depth > 1) + auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx); + if (gc_indices.empty()) { + struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); + break; + } + int num_gc = static_cast(gc_indices.size()); - // Create validity mask (field found OR has default = valid) - rmm::device_uvector valid(rows, stream, mr); - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(rows), - valid.begin(), - [locs = d_locations.data(), decoded_idx, 
num_decoded_fields, has_def] __device__( - auto row) { - return locs[row * num_decoded_fields + decoded_idx].offset >= 0 || has_def; - }); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + // Get child struct locations for grandchild scanning + // IMPORTANT: Need to compute ABSOLUTE offsets (relative to row start) + // d_child_locations contains offsets relative to parent message (Middle) + // We need: child_offset_in_row = parent_offset_in_row + child_offset_in_parent + std::vector h_parent_locs(num_rows); + std::vector h_child_locs_rel(num_rows); + CUDF_CUDA_TRY(cudaMemcpyAsync(h_parent_locs.data(), d_parent_locs.data(), + num_rows * sizeof(field_location), cudaMemcpyDeviceToHost, stream.value())); + for (int row = 0; row < num_rows; row++) { + CUDF_CUDA_TRY(cudaMemcpyAsync(&h_child_locs_rel[row], + d_child_locations.data() + row * num_child_fields + ci, + sizeof(field_location), cudaMemcpyDeviceToHost, stream.value())); + } + stream.synchronize(); + + // Compute absolute offsets + std::vector h_gc_parent_abs(num_rows); + for (int row = 0; row < num_rows; row++) { + if (h_parent_locs[row].offset >= 0 && h_child_locs_rel[row].offset >= 0) { + // Absolute offset = parent offset + child's relative offset + h_gc_parent_abs[row].offset = h_parent_locs[row].offset + h_child_locs_rel[row].offset; + h_gc_parent_abs[row].length = h_child_locs_rel[row].length; + } else { + h_gc_parent_abs[row] = {-1, 0}; + } + } + + rmm::device_uvector d_gc_parent(num_rows, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_gc_parent.data(), h_gc_parent_abs.data(), + num_rows * sizeof(field_location), cudaMemcpyHostToDevice, stream.value())); - // Create offsets column - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - rows + 1, - output_offsets.release(), - rmm::device_buffer{}, - 0); + // Build grandchild field descriptors + std::vector h_gc_descs(num_gc); + for (int gi = 0; gi < num_gc; gi++) { + h_gc_descs[gi].field_number = 
schema[gc_indices[gi]].field_number; + h_gc_descs[gi].expected_wire_type = schema[gc_indices[gi]].wire_type; + } + rmm::device_uvector d_gc_descs(num_gc, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_gc_descs.data(), h_gc_descs.data(), + num_gc * sizeof(field_descriptor), cudaMemcpyHostToDevice, stream.value())); - // Create INT8 child column directly (no intermediate strings column!) - auto child_col = std::make_unique(cudf::data_type{cudf::type_id::INT8}, - total_bytes, - child_data.release(), - rmm::device_buffer{}, - 0); + // Scan for grandchild fields + rmm::device_uvector d_gc_locs(num_rows * num_gc, stream, mr); + scan_nested_message_fields_kernel<<>>( + message_data, list_offsets, base_offset, d_gc_parent.data(), num_rows, + d_gc_descs.data(), num_gc, d_gc_locs.data(), d_error.data()); - all_children[schema_idx] = cudf::make_lists_column(rows, - std::move(offsets_col), - std::move(child_col), - null_count, - std::move(mask), - stream, - mr); - break; - } + // Extract grandchild values (handle scalar types only) + std::vector> gc_cols; + for (int gi = 0; gi < num_gc; gi++) { + int gc_idx = gc_indices[gi]; + auto gc_dt = schema_output_types[gc_idx]; + bool gc_def = schema[gc_idx].has_default_value; + if (gc_dt.id() == cudf::type_id::INT32) { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector val(num_rows, stream, mr); + int64_t dv = gc_def ? 
default_ints[gc_idx] : 0; + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_gc_parent.data(), + d_gc_locs.data(), gi, num_gc, out.data(), val.data(), num_rows, d_error.data(), gc_def, dv); + auto [m, nc] = make_null_mask_from_valid(val, stream, mr); + gc_cols.push_back(std::make_unique(gc_dt, num_rows, out.release(), std::move(m), nc)); + } else { + gc_cols.push_back(make_null_column(gc_dt, num_rows, stream, mr)); + } + } - default: CUDF_FAIL("Unsupported output type for protobuf decoder"); + // Build nested struct validity + rmm::device_uvector ns_valid(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), ns_valid.begin(), + [p = d_parent_locs.data(), c = d_child_locations.data(), ci, ncf = num_child_fields] __device__(auto r) { + return p[r].offset >= 0 && c[r * ncf + ci].offset >= 0; + }); + auto [ns_mask, ns_nc] = make_null_mask_from_valid(ns_valid, stream, mr); + struct_children.push_back(cudf::make_structs_column(num_rows, std::move(gc_cols), ns_nc, std::move(ns_mask), stream, mr)); + break; + } + default: + // For unsupported types, create null columns + struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); + break; + } } - decoded_idx++; - } else { - // This field is not decoded - create null column - all_children[schema_idx] = make_null_column(all_types[schema_idx], rows, stream, mr); + // Build struct validity based on parent location + rmm::device_uvector struct_valid(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + struct_valid.begin(), + [plocs = d_parent_locs.data()] __device__(auto row) { + return plocs[row].offset >= 0; + }); + auto [struct_mask, struct_null_count] = make_null_mask_from_valid(struct_valid, stream, mr); + + column_map[parent_schema_idx] = cudf::make_structs_column( + num_rows, 
std::move(struct_children), struct_null_count, std::move(struct_mask), stream, mr); + } + } + + // Assemble top_level_children in schema order (not processing order) + std::vector> top_level_children; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == -1) { // Top-level field + auto it = column_map.find(i); + if (it != column_map.end()) { + top_level_children.push_back(std::move(it->second)); + } else { + // Field not processed - create null column + top_level_children.push_back(make_null_column(schema_output_types[i], num_rows, stream, mr)); + } } } // Check for errors CUDF_CUDA_TRY(cudaPeekAtLastError()); - - // Check for any parse errors or missing required fields. - // Note: We check errors after all kernels complete rather than between kernel launches - // to avoid expensive synchronization overhead. If fail_on_errors is true and an error - // occurred, all kernels will have executed but we throw an exception here. int h_error = 0; - CUDF_CUDA_TRY( - cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); if (fail_on_errors) { - CUDF_EXPECTS(h_error == 0, - "Malformed protobuf message, unsupported wire type, or missing required field"); - } - - // Build the final struct - // If any rows have invalid enum values, create a null mask for the struct - // This matches Spark CPU PERMISSIVE mode: unknown enum values null the entire row - cudf::size_type struct_null_count = 0; - rmm::device_buffer struct_mask{0, stream, mr}; - - if (has_enum_fields) { - // Create struct null mask: row is valid if it has NO invalid enums - auto [mask, null_count] = cudf::detail::valid_if( - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(rows), - [row_invalid = d_row_has_invalid_enum.data()] __device__(cudf::size_type row) { - return !row_invalid[row]; // valid if NOT 
invalid - }, - stream, - mr); - struct_mask = std::move(mask); - struct_null_count = null_count; + CUDF_EXPECTS(h_error == 0, "Malformed protobuf message or unsupported wire type"); } - return cudf::make_structs_column( - rows, std::move(all_children), struct_null_count, std::move(struct_mask), stream, mr); + return cudf::make_structs_column(num_rows, std::move(top_level_children), 0, rmm::device_buffer{}, stream, mr); } } // namespace spark_rapids_jni diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index 0e398af39d..30a79c95c7 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -30,6 +30,25 @@ constexpr int ENC_DEFAULT = 0; constexpr int ENC_FIXED = 1; constexpr int ENC_ZIGZAG = 2; +// Maximum nesting depth for nested messages +constexpr int MAX_NESTING_DEPTH = 10; + +/** + * Descriptor for a field in a nested protobuf schema. + * Used to represent flattened schema with parent-child relationships. + */ +struct nested_field_descriptor { + int field_number; // Protobuf field number + int parent_idx; // Index of parent field in schema (-1 for top-level) + int depth; // Nesting depth (0 for top-level) + int wire_type; // Expected wire type + cudf::type_id output_type; // Output cudf type + int encoding; // Encoding type (ENC_DEFAULT, ENC_FIXED, ENC_ZIGZAG) + bool is_repeated; // Whether this field is repeated (array) + bool is_required; // Whether this field is required (proto2) + bool has_default_value; // Whether this field has a default value +}; + /** * Decode protobuf messages (one message per row) from a LIST column into a STRUCT * column. @@ -102,4 +121,35 @@ std::unique_ptr decode_protobuf_to_struct( std::vector> const& enum_valid_values, bool fail_on_errors); +/** + * Decode protobuf messages with support for nested messages and repeated fields. 
+ * + * This uses a multi-pass approach: + * - Pass 1: Scan all messages, count nested elements and repeated field occurrences + * - Pass 2: Prefix sum to compute output offsets for arrays and nested structs + * - Pass 3: Extract data using pre-computed offsets + * - Pass 4: Build nested column structure + * + * @param binary_input LIST column, each row is one protobuf message + * @param schema Flattened schema with parent-child relationships + * @param schema_output_types Output types for each field in schema (cudf types) + * @param default_ints Default values for int/long/enum fields + * @param default_floats Default values for float/double fields + * @param default_bools Default values for bool fields + * @param default_strings Default values for string/bytes fields + * @param enum_valid_values Valid enum values for each field (empty if not enum) + * @param fail_on_errors Whether to throw on malformed data + * @return STRUCT column with nested structure + */ +std::unique_ptr decode_nested_protobuf_to_struct( + cudf::column_view const& binary_input, + std::vector const& schema, + std::vector const& schema_output_types, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + bool fail_on_errors); + } // namespace spark_rapids_jni diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index e88064be0a..170be5e5c1 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -336,4 +336,121 @@ private static native long decodeToStruct(long binaryInputView, byte[][] defaultStrings, int[][] enumValidValues, boolean failOnErrors); + + // Wire type constants for nested schema + public static final int WT_VARINT = 0; + public static final int WT_64BIT = 1; + public static final int WT_LEN = 2; + 
public static final int WT_32BIT = 5; + + /** + * Decode protobuf messages with support for nested messages and repeated fields. + * + * This method uses a flattened schema representation where nested fields have parent indices + * pointing to their containing message field. + * + * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. + * @param fieldNumbers Protobuf field numbers for all fields in the flattened schema. + * @param parentIndices Parent field index for each field (-1 for top-level fields). + * @param depthLevels Nesting depth for each field (0 for top-level). + * @param wireTypes Expected wire type for each field (WT_VARINT, WT_64BIT, WT_LEN, WT_32BIT). + * @param outputTypeIds cudf native type ids for output columns. + * @param encodings Encoding info for each field (0=default, 1=fixed, 2=zigzag). + * @param isRepeated Whether each field is a repeated field (array). + * @param isRequired Whether each field is required (proto2). + * @param hasDefaultValue Whether each field has a default value. + * @param defaultInts Default values for int/long/enum fields. + * @param defaultFloats Default values for float/double fields. + * @param defaultBools Default values for bool fields. + * @param defaultStrings Default values for string/bytes fields as UTF-8 bytes. + * @param enumValidValues Valid enum values for each field (null if not an enum). + * @param failOnErrors if true, throw an exception on malformed protobuf messages. + * @return a cudf STRUCT column with nested structure. 
+ */ + public static ColumnVector decodeNestedToStruct(ColumnView binaryInput, + int[] fieldNumbers, + int[] parentIndices, + int[] depthLevels, + int[] wireTypes, + int[] outputTypeIds, + int[] encodings, + boolean[] isRepeated, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + int[][] enumValidValues, + boolean failOnErrors) { + // Parameter validation + if (fieldNumbers == null || parentIndices == null || depthLevels == null || + wireTypes == null || outputTypeIds == null || encodings == null || + isRepeated == null || isRequired == null || hasDefaultValue == null || + defaultInts == null || defaultFloats == null || defaultBools == null || + defaultStrings == null || enumValidValues == null) { + throw new IllegalArgumentException("Arrays must be non-null"); + } + + int numFields = fieldNumbers.length; + if (parentIndices.length != numFields || + depthLevels.length != numFields || + wireTypes.length != numFields || + outputTypeIds.length != numFields || + encodings.length != numFields || + isRepeated.length != numFields || + isRequired.length != numFields || + hasDefaultValue.length != numFields || + defaultInts.length != numFields || + defaultFloats.length != numFields || + defaultBools.length != numFields || + defaultStrings.length != numFields || + enumValidValues.length != numFields) { + throw new IllegalArgumentException("All arrays must have the same length"); + } + + // Validate field numbers are positive + for (int i = 0; i < fieldNumbers.length; i++) { + if (fieldNumbers[i] <= 0) { + throw new IllegalArgumentException( + "Invalid field number at index " + i + ": " + fieldNumbers[i] + + " (field numbers must be positive)"); + } + } + + // Validate encoding values + for (int i = 0; i < encodings.length; i++) { + int enc = encodings[i]; + if (enc < ENC_DEFAULT || enc > ENC_ZIGZAG) { + throw new IllegalArgumentException( + "Invalid encoding value at 
index " + i + ": " + enc + + " (expected " + ENC_DEFAULT + ", " + ENC_FIXED + ", or " + ENC_ZIGZAG + ")"); + } + } + + long handle = decodeNestedToStruct(binaryInput.getNativeView(), + fieldNumbers, parentIndices, depthLevels, + wireTypes, outputTypeIds, encodings, + isRepeated, isRequired, hasDefaultValue, + defaultInts, defaultFloats, defaultBools, + defaultStrings, enumValidValues, failOnErrors); + return new ColumnVector(handle); + } + + private static native long decodeNestedToStruct(long binaryInputView, + int[] fieldNumbers, + int[] parentIndices, + int[] depthLevels, + int[] wireTypes, + int[] outputTypeIds, + int[] encodings, + boolean[] isRepeated, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + int[][] enumValidValues, + boolean failOnErrors); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index f7450cda82..044812f128 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -26,6 +26,8 @@ import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -1615,26 +1617,46 @@ void testDefaultStringMultipleRows() { } // ============================================================================ - // Tests for Features Not Yet Implemented (Disabled) + // Tests for Nested and Repeated Fields (Phase 1-3 Implementation) // ============================================================================ - @Disabled("Unpacked repeated fields not yet implemented") @Test void testUnpackedRepeatedInt32() { // Unpacked repeated: same field 
number appears multiple times + // message TestMsg { repeated int32 ids = 1; } Byte[] row = concat( box(tag(1, WT_VARINT)), box(encodeVarint(1)), box(tag(1, WT_VARINT)), box(encodeVarint(2)), box(tag(1, WT_VARINT)), box(encodeVarint(3))); - // Expected: ARRAY with values [1, 2, 3] - // (Currently we implement "last one wins" semantics for scalars) try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { - // TODO: implement unpacked repeated field decoding + // Use the new nested API for repeated fields + // Field: ids (field_number=1, parent=-1, depth=0, wire_type=VARINT, type=INT32, repeated=true) + try (ColumnVector result = Protobuf.decodeNestedToStruct( + input.getColumn(0), + new int[]{1}, // fieldNumbers + new int[]{-1}, // parentIndices (-1 = top level) + new int[]{0}, // depthLevels + new int[]{Protobuf.WT_VARINT}, // wireTypes + new int[]{DType.INT32.getTypeId().getNativeId()}, // outputTypeIds (element type) + new int[]{Protobuf.ENC_DEFAULT}, // encodings + new boolean[]{true}, // isRepeated + new boolean[]{false}, // isRequired + new boolean[]{false}, // hasDefaultValue + new long[]{0}, // defaultInts + new double[]{0.0}, // defaultFloats + new boolean[]{false}, // defaultBools + new byte[][]{null}, // defaultStrings + new int[][]{null}, // enumValidValues + false)) { // failOnErrors + // Result should be STRUCT> + // The list should contain [1, 2, 3] + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + } } } - @Disabled("Nested messages not yet implemented") @Test void testNestedMessage() { // message Inner { int32 x = 1; } @@ -1647,23 +1669,29 @@ void testNestedMessage() { innerMessage); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { - // TODO: implement nested message decoding - // Expected: STRUCT> - } - } - - @Disabled("Large field numbers not tested with current API") - @Test - void testLargeFieldNumber() { - // Field numbers can be up to 2^29 - 1 = 536870911 - int 
largeFieldNum = 536870911; - Byte[] row = concat( - box(tag(largeFieldNum, WT_VARINT)), - box(encodeVarint(42))); - - try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { - // Current API uses int[] for field numbers, should work - // But need to verify kernel handles large field numbers correctly + // Flattened schema: + // [0] inner: STRUCT, field_number=1, parent=-1, depth=0 + // [1] inner.x: INT32, field_number=1, parent=0, depth=1 + try (ColumnVector result = Protobuf.decodeNestedToStruct( + input.getColumn(0), + new int[]{1, 1}, // fieldNumbers + new int[]{-1, 0}, // parentIndices + new int[]{0, 1}, // depthLevels + new int[]{Protobuf.WT_LEN, Protobuf.WT_VARINT}, // wireTypes + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, // isRepeated + new boolean[]{false, false}, // isRequired + new boolean[]{false, false}, // hasDefaultValue + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + } } } From 5f89e6060f9c31cf77a072a9d36a881a225b0cdf Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Feb 2026 21:22:42 +0800 Subject: [PATCH 016/107] performance optimization Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 540 ++++++++++++++++++++--------------- 1 file changed, 307 insertions(+), 233 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index e887e9b017..0908527e76 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -1493,8 +1493,65 @@ __global__ void extract_repeated_msg_child_fixed_kernel( valid[occ_idx] = true; } +/** + * Kernel to extract string data from repeated message child fields. + * Copies all strings in parallel on the GPU instead of per-string host copies. 
+ */ +__global__ void extract_repeated_msg_child_strings_kernel( + uint8_t const* message_data, + int32_t const* msg_row_offsets, + field_location const* msg_locs, + field_location const* child_locs, + int child_idx, + int num_child_fields, + int32_t const* string_offsets, // Output offsets (exclusive scan of lengths) + char* output_chars, + bool* valid, + int total_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_count) return; + + auto const& field_loc = child_locs[idx * num_child_fields + child_idx]; + + if (field_loc.offset < 0 || field_loc.length == 0) { + valid[idx] = false; + return; + } + + valid[idx] = true; + + int32_t row_offset = msg_row_offsets[idx]; + int32_t msg_offset = msg_locs[idx].offset; + uint8_t const* str_src = message_data + row_offset + msg_offset + field_loc.offset; + char* str_dst = output_chars + string_offsets[idx]; + + // Copy string data + for (int i = 0; i < field_loc.length; i++) { + str_dst[i] = static_cast(str_src[i]); + } +} + +/** + * Kernel to compute string lengths from child field locations. + */ +__global__ void compute_string_lengths_kernel( + field_location const* child_locs, + int child_idx, + int num_child_fields, + int32_t* lengths, + int total_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_count) return; + + auto const& loc = child_locs[idx * num_child_fields + child_idx]; + lengths[idx] = (loc.offset >= 0) ? loc.length : 0; +} + /** * Helper to build string column for repeated message child fields. + * Uses GPU kernels for parallel string extraction (critical performance fix!). 
*/ inline std::unique_ptr build_repeated_msg_child_string_column( uint8_t const* message_data, @@ -1512,86 +1569,148 @@ inline std::unique_ptr build_repeated_msg_child_string_column( return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); } - // Get string lengths from child_locs - std::vector h_child_locs(total_count * num_child_fields); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_child_locs.data(), d_child_locs.data(), - h_child_locs.size() * sizeof(field_location), - cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); + auto const threads = 256; + auto const blocks = (total_count + threads - 1) / threads; + + // Compute string lengths on GPU + rmm::device_uvector d_lengths(total_count, stream, mr); + compute_string_lengths_kernel<<>>( + d_child_locs.data(), child_idx, num_child_fields, d_lengths.data(), total_count); - std::vector h_lengths(total_count); + // Compute offsets via exclusive scan + rmm::device_uvector d_str_offsets(total_count + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), + d_lengths.begin(), d_lengths.end(), + d_str_offsets.begin(), 0); + + // Get total chars count int32_t total_chars = 0; - for (int i = 0; i < total_count; i++) { - auto const& loc = h_child_locs[i * num_child_fields + child_idx]; - if (loc.offset >= 0) { - h_lengths[i] = loc.length; - total_chars += loc.length; - } else { - h_lengths[i] = 0; - } - } + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, d_str_offsets.data() + total_count - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, d_lengths.data() + total_count - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + total_chars += last_len; + + // Set final offset + CUDF_CUDA_TRY(cudaMemcpyAsync(d_str_offsets.data() + total_count, &total_chars, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - // Build string offsets - rmm::device_uvector 
str_offsets(total_count + 1, stream, mr); - std::vector h_offsets(total_count + 1); - h_offsets[0] = 0; - for (int i = 0; i < total_count; i++) { - h_offsets[i + 1] = h_offsets[i] + h_lengths[i]; - } - CUDF_CUDA_TRY(cudaMemcpyAsync(str_offsets.data(), h_offsets.data(), - (total_count + 1) * sizeof(int32_t), - cudaMemcpyHostToDevice, stream.value())); + // Allocate output chars and validity + rmm::device_uvector d_chars(total_chars, stream, mr); + rmm::device_uvector d_valid(total_count, stream, mr); - // Copy string data - rmm::device_uvector chars(total_chars, stream, mr); + // Extract all strings in parallel on GPU (critical performance fix!) if (total_chars > 0) { - std::vector h_msg_locs(total_count); - std::vector h_row_offsets(total_count); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_msg_locs.data(), d_msg_locs.data(), - total_count * sizeof(field_location), - cudaMemcpyDeviceToHost, stream.value())); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_row_offsets.data(), d_msg_row_offsets.data(), - total_count * sizeof(int32_t), - cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - - // Copy each string on host (not ideal but works) - std::vector h_chars(total_chars); - int char_idx = 0; - for (int i = 0; i < total_count; i++) { - auto const& field_loc = h_child_locs[i * num_child_fields + child_idx]; - if (field_loc.offset >= 0 && field_loc.length > 0) { - int32_t row_offset = h_row_offsets[i]; - int32_t msg_offset = h_msg_locs[i].offset; - uint8_t const* str_ptr = message_data + row_offset + msg_offset + field_loc.offset; - // Need to copy from device - use cudaMemcpy - CUDF_CUDA_TRY(cudaMemcpy(h_chars.data() + char_idx, str_ptr, - field_loc.length, cudaMemcpyDeviceToHost)); - char_idx += field_loc.length; - } - } - CUDF_CUDA_TRY(cudaMemcpyAsync(chars.data(), h_chars.data(), - total_chars, cudaMemcpyHostToDevice, stream.value())); + extract_repeated_msg_child_strings_kernel<<>>( + message_data, d_msg_row_offsets.data(), d_msg_locs.data(), + d_child_locs.data(), 
child_idx, num_child_fields, + d_str_offsets.data(), d_chars.data(), d_valid.data(), total_count); + } else { + // No strings, just set validity + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + d_valid.begin(), + [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { + return child_locs[idx * ncf + ci].offset >= 0; + }); } - // Build validity mask - rmm::device_uvector valid(total_count, stream, mr); - std::vector h_valid(total_count); - for (int i = 0; i < total_count; i++) { - h_valid[i] = (h_child_locs[i * num_child_fields + child_idx].offset >= 0) ? 1 : 0; - } - rmm::device_uvector d_valid_u8(total_count, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_u8.data(), h_valid.data(), - total_count * sizeof(uint8_t), - cudaMemcpyHostToDevice, stream.value())); - - auto [mask, null_count] = make_null_mask_from_valid(d_valid_u8, stream, mr); + auto [mask, null_count] = make_null_mask_from_valid(d_valid, stream, mr); auto str_offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, total_count + 1, str_offsets.release(), rmm::device_buffer{}, 0); - return cudf::make_strings_column(total_count, std::move(str_offsets_col), chars.release(), null_count, std::move(mask)); + cudf::data_type{cudf::type_id::INT32}, total_count + 1, d_str_offsets.release(), rmm::device_buffer{}, 0); + return cudf::make_strings_column(total_count, std::move(str_offsets_col), d_chars.release(), null_count, std::move(mask)); } +/** + * Kernel to compute nested struct locations from child field locations. + * Replaces host-side loop that was copying data D->H, processing, then H->D. + * This is a critical performance optimization. 
+ */ +__global__ void compute_nested_struct_locations_kernel( + field_location const* child_locs, // Child field locations from parent scan + field_location const* msg_locs, // Parent message locations + int32_t const* msg_row_offsets, // Parent message row offsets + int child_idx, // Which child field is the nested struct + int num_child_fields, // Total number of child fields per occurrence + field_location* nested_locs, // Output: nested struct locations + int32_t* nested_row_offsets, // Output: nested struct row offsets + int total_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_count) return; + + // Get the nested struct location from child_locs + nested_locs[idx] = child_locs[idx * num_child_fields + child_idx]; + // Compute absolute row offset = msg_row_offset + msg_offset + nested_row_offsets[idx] = msg_row_offsets[idx] + msg_locs[idx].offset; +} + +/** + * Kernel to compute absolute grandchild parent locations from parent and child locations. + * Computes: gc_parent_abs[i] = parent[i].offset + child[i * ncf + ci].offset + * This replaces host-side loop with D->H->D copy pattern. 
+ */ +__global__ void compute_grandchild_parent_locations_kernel( + field_location const* parent_locs, // Parent locations (row count) + field_location const* child_locs, // Child locations (row * num_child_fields) + int child_idx, // Which child field + int num_child_fields, // Total child fields per row + field_location* gc_parent_abs, // Output: absolute grandchild parent locations + int num_rows) +{ + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= num_rows) return; + + auto const& parent_loc = parent_locs[row]; + auto const& child_loc = child_locs[row * num_child_fields + child_idx]; + + if (parent_loc.offset >= 0 && child_loc.offset >= 0) { + // Absolute offset = parent offset + child's relative offset + gc_parent_abs[row].offset = parent_loc.offset + child_loc.offset; + gc_parent_abs[row].length = child_loc.length; + } else { + gc_parent_abs[row] = {-1, 0}; + } +} + +/** + * Kernel to compute message locations and row offsets from repeated occurrences. + * Replaces host-side loop that processed occurrences. + */ +__global__ void compute_msg_locations_from_occurrences_kernel( + repeated_occurrence const* occurrences, // Repeated field occurrences + cudf::size_type const* list_offsets, // List offsets for rows + cudf::size_type base_offset, // Base offset to subtract + field_location* msg_locs, // Output: message locations + int32_t* msg_row_offsets, // Output: message row offsets + int total_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_count) return; + + auto const& occ = occurrences[idx]; + msg_row_offsets[idx] = static_cast(list_offsets[occ.row_idx] - base_offset); + msg_locs[idx] = {occ.offset, occ.length}; +} + +/** + * Functor to extract count from repeated_field_info with strided access. + * Used for extracting counts for a specific repeated field from 2D array. 
+ */ +struct extract_strided_count { + repeated_field_info const* info; + int field_idx; + int num_fields; + + __device__ int32_t operator()(int row) const { + return info[row * num_fields + field_idx].count; + } +}; + /** * Extract varint from nested message locations. */ @@ -2898,26 +3017,25 @@ std::unique_ptr build_repeated_scalar_column( cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - // Build list offsets from counts - rmm::device_uvector counts(num_rows, stream, mr); - std::vector h_counts(num_rows); - for (int i = 0; i < num_rows; i++) { - h_counts[i] = h_repeated_info[i].count; - } - CUDF_CUDA_TRY(cudaMemcpyAsync(counts.data(), h_counts.data(), num_rows * sizeof(int32_t), + // Build list offsets from counts entirely on GPU (performance fix!) + // Copy h_repeated_info to device and use thrust::transform to extract counts + rmm::device_uvector d_rep_info(num_rows, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_info.data(), h_repeated_info.data(), + num_rows * sizeof(repeated_field_info), cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector counts(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_info.begin(), d_rep_info.end(), + counts.begin(), + [] __device__(repeated_field_info const& info) { return info.count; }); rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan(rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); - int32_t last_offset_h = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_offset_h, list_offs.data() + num_rows - 1, sizeof(int32_t), - cudaMemcpyDeviceToHost, stream.value())); - int32_t last_count_h = h_counts[num_rows - 1]; - stream.synchronize(); - last_offset_h += last_count_h; - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &last_offset_h, sizeof(int32_t), - cudaMemcpyHostToDevice, stream.value())); + // Set last offset = total_count + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_count, + 
sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); // Extract values rmm::device_uvector values(total_count, stream, mr); @@ -3018,26 +3136,25 @@ std::unique_ptr build_repeated_string_column( cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - // Build list offsets from counts - rmm::device_uvector counts(num_rows, stream, mr); - std::vector h_counts(num_rows); - for (int i = 0; i < num_rows; i++) { - h_counts[i] = h_repeated_info[i].count; - } - CUDF_CUDA_TRY(cudaMemcpyAsync(counts.data(), h_counts.data(), num_rows * sizeof(int32_t), + // Build list offsets from counts entirely on GPU (performance fix!) + // Copy h_repeated_info to device and use thrust::transform to extract counts + rmm::device_uvector d_rep_info(num_rows, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_info.data(), h_repeated_info.data(), + num_rows * sizeof(repeated_field_info), cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector counts(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_info.begin(), d_rep_info.end(), + counts.begin(), + [] __device__(repeated_field_info const& info) { return info.count; }); rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan(rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); - - int32_t last_offset_h = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_offset_h, list_offs.data() + num_rows - 1, sizeof(int32_t), - cudaMemcpyDeviceToHost, stream.value())); - int32_t last_count_h = h_counts[num_rows - 1]; - stream.synchronize(); - last_offset_h += last_count_h; - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &last_offset_h, sizeof(int32_t), - cudaMemcpyHostToDevice, stream.value())); + + // Set last offset = total_count + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_count, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); // Extract string lengths from occurrences rmm::device_uvector str_lengths(total_count, 
stream, mr); @@ -3168,32 +3285,25 @@ std::unique_ptr build_repeated_struct_column( cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - // Build list offsets from counts (for the outer LIST column) - rmm::device_uvector list_offs(num_rows + 1, stream, mr); - std::vector h_counts(num_rows); - for (int i = 0; i < num_rows; i++) { - h_counts[i] = h_repeated_info[i].count; - } - rmm::device_uvector counts(num_rows, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(counts.data(), h_counts.data(), num_rows * sizeof(int32_t), + // Build list offsets from counts entirely on GPU (performance fix!) + // Copy repeated_info to device and use thrust::transform to extract counts + rmm::device_uvector d_rep_info(num_rows, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_info.data(), h_repeated_info.data(), + num_rows * sizeof(repeated_field_info), cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector counts(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_info.begin(), d_rep_info.end(), + counts.begin(), + [] __device__(repeated_field_info const& info) { return info.count; }); + + rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan(rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); - int32_t last_offset_h = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_offset_h, list_offs.data() + num_rows - 1, sizeof(int32_t), - cudaMemcpyDeviceToHost, stream.value())); - int32_t last_count_h = h_counts[num_rows - 1]; - stream.synchronize(); - last_offset_h += last_count_h; - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &last_offset_h, sizeof(int32_t), - cudaMemcpyHostToDevice, stream.value())); - - // Copy occurrences to host for processing - std::vector h_occurrences(total_count); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_occurrences.data(), d_occurrences.data(), - total_count * sizeof(repeated_occurrence), - cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); + // Set last 
offset = total_count (already computed on caller side) + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_count, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); // Build child field descriptors for scanning within each message occurrence std::vector h_child_descs(num_child_fields); @@ -3208,29 +3318,17 @@ std::unique_ptr build_repeated_struct_column( cudaMemcpyHostToDevice, stream.value())); // For each occurrence, we need to scan for child fields - // Create "virtual" parent locations from the occurrences - // Each occurrence becomes a "parent" message for child field scanning - std::vector h_msg_locs(total_count); - std::vector h_msg_row_offsets(total_count); - for (int i = 0; i < total_count; i++) { - auto const& occ = h_occurrences[i]; - // Get the row's start offset in the binary column - cudf::size_type row_offset; - CUDF_CUDA_TRY(cudaMemcpyAsync(&row_offset, list_offsets + occ.row_idx, sizeof(cudf::size_type), - cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - h_msg_row_offsets[i] = static_cast(row_offset - base_offset); - h_msg_locs[i] = {occ.offset, occ.length}; - } - + // Create "virtual" parent locations from the occurrences using GPU kernel + // This replaces the host-side loop with D->H->D copy pattern (critical performance fix!) 
rmm::device_uvector d_msg_locs(total_count, stream, mr); rmm::device_uvector d_msg_row_offsets(total_count, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_msg_locs.data(), h_msg_locs.data(), - total_count * sizeof(field_location), - cudaMemcpyHostToDevice, stream.value())); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_msg_row_offsets.data(), h_msg_row_offsets.data(), - total_count * sizeof(int32_t), - cudaMemcpyHostToDevice, stream.value())); + { + auto const occ_threads = 256; + auto const occ_blocks = (total_count + occ_threads - 1) / occ_threads; + compute_msg_locations_from_occurrences_kernel<<>>( + d_occurrences.data(), list_offsets, base_offset, + d_msg_locs.data(), d_msg_row_offsets.data(), total_count); + } // Scan for child fields within each message occurrence rmm::device_uvector d_child_locs(total_count * num_child_fields, stream, mr); @@ -3246,12 +3344,10 @@ std::unique_ptr build_repeated_struct_column( message_data, d_msg_row_offsets.data(), d_msg_locs.data(), total_count, d_child_descs.data(), num_child_fields, d_child_locs.data(), d_error.data()); - // Copy child locations to host - std::vector h_child_locs(total_count * num_child_fields); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_child_locs.data(), d_child_locs.data(), - h_child_locs.size() * sizeof(field_location), - cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); + // Note: We no longer need to copy child_locs to host because: + // 1. All scalar extraction kernels access d_child_locs directly on device + // 2. String extraction uses GPU kernels + // 3. 
Nested struct locations are computed on GPU via compute_nested_struct_locations_kernel // Extract child field values - build one column per child field std::vector> struct_children; @@ -3386,25 +3482,13 @@ std::unique_ptr build_repeated_struct_column( num_grandchildren * sizeof(field_descriptor), cudaMemcpyHostToDevice, stream.value())); - // Create nested struct locations from child_locs - // Each occurrence's nested struct is at child_locs[occ * num_child_fields + ci] - std::vector h_nested_locs(total_count); - std::vector h_nested_row_offsets(total_count); - for (int occ = 0; occ < total_count; occ++) { - auto const& nested_loc = h_child_locs[occ * num_child_fields + ci]; - auto const& msg_loc = h_msg_locs[occ]; - h_nested_row_offsets[occ] = h_msg_row_offsets[occ] + msg_loc.offset; - h_nested_locs[occ] = nested_loc; - } - + // Create nested struct locations from child_locs using GPU kernel + // This eliminates the D->H->D copy pattern (critical performance optimization) rmm::device_uvector d_nested_locs(total_count, stream, mr); rmm::device_uvector d_nested_row_offsets(total_count, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_nested_locs.data(), h_nested_locs.data(), - total_count * sizeof(field_location), - cudaMemcpyHostToDevice, stream.value())); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_nested_row_offsets.data(), h_nested_row_offsets.data(), - total_count * sizeof(int32_t), - cudaMemcpyHostToDevice, stream.value())); + compute_nested_struct_locations_kernel<<>>( + d_child_locs.data(), d_msg_locs.data(), d_msg_row_offsets.data(), + ci, num_child_fields, d_nested_locs.data(), d_nested_row_offsets.data(), total_count); // Scan for grandchild fields rmm::device_uvector d_gc_locs(total_count * num_grandchildren, stream, mr); @@ -3849,25 +3933,32 @@ std::unique_ptr decode_nested_protobuf_to_struct( int schema_idx = repeated_field_indices[ri]; auto element_type = schema_output_types[schema_idx]; - // Get per-row info for this repeated field + // Get per-row counts for 
this repeated field entirely on GPU (performance fix!) + rmm::device_uvector d_field_counts(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + d_field_counts.begin(), + extract_strided_count{d_repeated_info.data(), ri, num_repeated}); + + int total_count = thrust::reduce(rmm::exec_policy(stream), + d_field_counts.begin(), d_field_counts.end(), 0); + + // Still need host-side field_info for build_repeated_scalar_column std::vector field_info(num_rows); - int total_count = 0; for (int row = 0; row < num_rows; row++) { field_info[row] = h_repeated_info[row * num_repeated + ri]; - total_count += field_info[row].count; } if (total_count > 0) { - // Build offsets for occurrence scanning + // Build offsets for occurrence scanning on GPU (performance fix!) rmm::device_uvector d_occ_offsets(num_rows + 1, stream, mr); - std::vector h_occ_offsets(num_rows + 1); - h_occ_offsets[0] = 0; - for (int row = 0; row < num_rows; row++) { - h_occ_offsets[row + 1] = h_occ_offsets[row] + field_info[row].count; - } - CUDF_CUDA_TRY(cudaMemcpyAsync(d_occ_offsets.data(), h_occ_offsets.data(), - (num_rows + 1) * sizeof(int32_t), - cudaMemcpyHostToDevice, stream.value())); + thrust::exclusive_scan(rmm::exec_policy(stream), + d_field_counts.begin(), d_field_counts.end(), + d_occ_offsets.begin(), 0); + // Set last element + CUDF_CUDA_TRY(cudaMemcpyAsync(d_occ_offsets.data() + num_rows, &total_count, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); // Scan for all occurrences rmm::device_uvector d_occurrences(total_count, stream, mr); @@ -4075,15 +4166,16 @@ std::unique_ptr decode_nested_protobuf_to_struct( message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, d_rep_schema.data(), 1, d_rep_info.data(), 1, d_rep_indices.data(), d_error.data()); - std::vector h_rep_info(num_rows); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_rep_info.data(), d_rep_info.data(), - num_rows * 
sizeof(repeated_field_info), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); + // Compute total_rep_count on GPU using thrust::reduce (performance fix!) + // Extract counts from repeated_field_info on device + rmm::device_uvector d_rep_counts(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_info.begin(), d_rep_info.end(), + d_rep_counts.begin(), + [] __device__(repeated_field_info const& info) { return info.count; }); - int total_rep_count = 0; - for (int row = 0; row < num_rows; row++) { - total_rep_count += h_rep_info[row].count; - } + int total_rep_count = thrust::reduce(rmm::exec_policy(stream), + d_rep_counts.begin(), d_rep_counts.end(), 0); if (total_rep_count == 0) { rmm::device_uvector list_offsets_vec(num_rows + 1, stream, mr); @@ -4100,14 +4192,14 @@ std::unique_ptr decode_nested_protobuf_to_struct( d_rep_schema.data(), 1, d_rep_info.data(), 1, d_rep_indices.data(), d_rep_occs.data(), d_error.data()); + // Compute list offsets on GPU using exclusive_scan (performance fix!) 
rmm::device_uvector list_offs(num_rows + 1, stream, mr); - std::vector h_list_offs(num_rows + 1); - h_list_offs[0] = 0; - for (int row = 0; row < num_rows; row++) { - h_list_offs[row + 1] = h_list_offs[row] + h_rep_info[row].count; - } - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data(), h_list_offs.data(), - (num_rows + 1) * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + thrust::exclusive_scan(rmm::exec_policy(stream), + d_rep_counts.begin(), d_rep_counts.end(), + list_offs.begin(), 0); + // Set last element + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_rep_count, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); std::unique_ptr child_values; if (elem_type_id == cudf::type_id::INT32) { @@ -4125,22 +4217,25 @@ std::unique_ptr decode_nested_protobuf_to_struct( child_values = std::make_unique( cudf::data_type{cudf::type_id::INT64}, total_rep_count, values.release(), rmm::device_buffer{}, 0); } else if (elem_type_id == cudf::type_id::STRING) { - std::vector h_rep_occs(total_rep_count); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_rep_occs.data(), d_rep_occs.data(), - total_rep_count * sizeof(repeated_occurrence), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); + // Compute string offsets on GPU using thrust (performance fix!) 
+ // Extract lengths from occurrences on device + rmm::device_uvector d_str_lengths(total_rep_count, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_occs.begin(), d_rep_occs.end(), + d_str_lengths.begin(), + [] __device__(repeated_occurrence const& occ) { return occ.length; }); - int32_t total_chars = 0; - std::vector h_str_offs(total_rep_count + 1); - h_str_offs[0] = 0; - for (int i = 0; i < total_rep_count; i++) { - h_str_offs[i + 1] = h_str_offs[i] + h_rep_occs[i].length; - total_chars += h_rep_occs[i].length; - } + // Compute total chars and offsets + int32_t total_chars = thrust::reduce(rmm::exec_policy(stream), + d_str_lengths.begin(), d_str_lengths.end(), 0); rmm::device_uvector str_offs(total_rep_count + 1, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data(), h_str_offs.data(), - (total_rep_count + 1) * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + thrust::exclusive_scan(rmm::exec_policy(stream), + d_str_lengths.begin(), d_str_lengths.end(), + str_offs.begin(), 0); + // Set last element + CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data() + total_rep_count, &total_chars, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { @@ -4322,36 +4417,15 @@ std::unique_ptr decode_nested_protobuf_to_struct( } int num_gc = static_cast(gc_indices.size()); - // Get child struct locations for grandchild scanning + // Get child struct locations for grandchild scanning using GPU kernel // IMPORTANT: Need to compute ABSOLUTE offsets (relative to row start) // d_child_locations contains offsets relative to parent message (Middle) // We need: child_offset_in_row = parent_offset_in_row + child_offset_in_parent - std::vector h_parent_locs(num_rows); - std::vector h_child_locs_rel(num_rows); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_parent_locs.data(), d_parent_locs.data(), - num_rows * sizeof(field_location), cudaMemcpyDeviceToHost, stream.value())); - for (int row = 
0; row < num_rows; row++) { - CUDF_CUDA_TRY(cudaMemcpyAsync(&h_child_locs_rel[row], - d_child_locations.data() + row * num_child_fields + ci, - sizeof(field_location), cudaMemcpyDeviceToHost, stream.value())); - } - stream.synchronize(); - - // Compute absolute offsets - std::vector h_gc_parent_abs(num_rows); - for (int row = 0; row < num_rows; row++) { - if (h_parent_locs[row].offset >= 0 && h_child_locs_rel[row].offset >= 0) { - // Absolute offset = parent offset + child's relative offset - h_gc_parent_abs[row].offset = h_parent_locs[row].offset + h_child_locs_rel[row].offset; - h_gc_parent_abs[row].length = h_child_locs_rel[row].length; - } else { - h_gc_parent_abs[row] = {-1, 0}; - } - } - + // This is computed entirely on GPU to avoid D->H->D copy pattern (performance fix!) rmm::device_uvector d_gc_parent(num_rows, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_gc_parent.data(), h_gc_parent_abs.data(), - num_rows * sizeof(field_location), cudaMemcpyHostToDevice, stream.value())); + compute_grandchild_parent_locations_kernel<<>>( + d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields, + d_gc_parent.data(), num_rows); // Build grandchild field descriptors std::vector h_gc_descs(num_gc); From e311b4bcb6d6b46c5b6eee91c34592dfa923ed87 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 6 Feb 2026 17:11:18 +0800 Subject: [PATCH 017/107] single-pass-kernel, with debug log, met unbreakable wall Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 1379 +++++++++++++++++++++++++++++++++- 1 file changed, 1377 insertions(+), 2 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 0908527e76..a762ffc1cd 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -32,12 +32,20 @@ #include #include +#include +#include +#include #include +#include #include +#include +#include -#include -#include +#include #include +#include +#include +#include #include namespace { @@ -101,6 +109,61 @@ struct 
device_nested_field_descriptor { bool has_default_value; }; +// ============================================================================ +// Single-pass decoder data structures +// ============================================================================ + +/// Maximum nesting depth for single-pass decoder +constexpr int SP_MAX_DEPTH = 10; + +/// Maximum number of counted columns (repeated fields at all depths) +constexpr int SP_MAX_COUNTED = 128; + +/// Maximum number of output columns +constexpr int SP_MAX_OUTPUT_COLS = 512; + +/// Message type descriptor: groups fields belonging to the same protobuf message +struct sp_msg_type { + int first_field_idx; // Start index in the global sp_field_entry array + int num_fields; // Number of direct child fields + int lookup_offset; // Offset into d_field_lookup table (-1 if not using lookup) + int max_field_number; // Max field number + 1 (size of lookup region) +}; + +/// Field entry for single-pass decoder (device-side, sorted by field_number per msg type) +struct sp_field_entry { + int field_number; // Protobuf field number + int wire_type; // Expected wire type + int output_type_id; // cudf type_id cast to int (-1 for struct containers) + int encoding; // ENC_DEFAULT / ENC_FIXED / ENC_ZIGZAG + int child_msg_type; // For nested messages: index into sp_msg_type (-1 otherwise) + int col_idx; // Index into output column descriptors (-1 for containers) + int count_idx; // For repeated fields: index into per-row count array (-1 if not) + bool is_repeated; // Whether this field is repeated + bool has_default; // Whether this field has a default value + int64_t default_int; // Default value for int/long/bool + double default_float; // Default value for float/double +}; + +/// Stack entry for nested message parsing within a kernel thread +struct sp_stack_entry { + int parent_end_offset; // End offset of parent message (relative to row start) + int msg_type_idx; // Saved message type index + int write_base; // Saved write 
base for non-repeated children +}; + +/// Output column descriptor (device-side, used during Pass 2) +struct sp_col_desc { + void* data; // Typed data buffer (or string_index_pair* for strings) + bool* validity; // Validity buffer (one bool per element) +}; + +/// Pair for zero-copy string references (device-side) +struct sp_string_pair { + char const* ptr; // Pointer into message data (null if not found) + int32_t length; // String length in bytes (0 if not found) +}; + // ============================================================================ // Device helper functions // ============================================================================ @@ -3590,6 +3653,1298 @@ std::unique_ptr build_repeated_struct_column( return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(struct_col), 0, rmm::device_buffer{}, stream, mr); } +// ============================================================================ +// Single-Pass Decoder Implementation +// ============================================================================ + +/** + * O(1) field lookup using direct-mapped table. + * d_field_lookup[msg_type.lookup_offset + field_number] = field_entry index, or -1. + */ +__device__ inline int sp_lookup_field( + sp_msg_type const* msg_types, + sp_field_entry const* /*field_entries*/, + int const* d_field_lookup, + int msg_type_idx, + int field_number) +{ + auto const& mt = msg_types[msg_type_idx]; + if (field_number < 0 || field_number >= mt.max_field_number) return -1; + return d_field_lookup[mt.lookup_offset + field_number]; +} + +/** + * Write an extracted scalar value to the output column. + * cur is advanced past the consumed bytes. 
+ */ +__device__ inline void sp_write_scalar( + uint8_t const*& cur, + uint8_t const* end, + sp_field_entry const& fe, + sp_col_desc* col_descs, + int write_pos) +{ + if (fe.col_idx < 0) return; + auto& cd = col_descs[fe.col_idx]; + + if (fe.wire_type == WT_VARINT) { + uint64_t val; int vb; + if (!read_varint(cur, end, val, vb)) return; + cur += vb; + if (fe.encoding == spark_rapids_jni::ENC_ZIGZAG) { + val = (val >> 1) ^ (-(val & 1)); + } + int tid = fe.output_type_id; + if (tid == static_cast(cudf::type_id::BOOL8)) + reinterpret_cast(cd.data)[write_pos] = val ? 1 : 0; + else if (tid == static_cast(cudf::type_id::INT32)) + reinterpret_cast(cd.data)[write_pos] = static_cast(val); + else if (tid == static_cast(cudf::type_id::UINT32)) + reinterpret_cast(cd.data)[write_pos] = static_cast(val); + else if (tid == static_cast(cudf::type_id::INT64)) + reinterpret_cast(cd.data)[write_pos] = static_cast(val); + else if (tid == static_cast(cudf::type_id::UINT64)) + reinterpret_cast(cd.data)[write_pos] = val; + cd.validity[write_pos] = true; + + } else if (fe.wire_type == WT_32BIT) { + if (end - cur < 4) return; + uint32_t raw = load_le(cur); + cur += 4; + int tid = fe.output_type_id; + if (tid == static_cast(cudf::type_id::FLOAT32)) { + float f; memcpy(&f, &raw, 4); + reinterpret_cast(cd.data)[write_pos] = f; + } else { + reinterpret_cast(cd.data)[write_pos] = static_cast(raw); + } + cd.validity[write_pos] = true; + + } else if (fe.wire_type == WT_64BIT) { + if (end - cur < 8) return; + uint64_t raw = load_le(cur); + cur += 8; + int tid = fe.output_type_id; + if (tid == static_cast(cudf::type_id::FLOAT64)) { + double d; memcpy(&d, &raw, 8); + reinterpret_cast(cd.data)[write_pos] = d; + } else { + reinterpret_cast(cd.data)[write_pos] = static_cast(raw); + } + cd.validity[write_pos] = true; + + } else if (fe.wire_type == WT_LEN) { + // String / bytes + uint64_t len; int lb; + if (!read_varint(cur, end, len, lb)) return; + auto* pairs = reinterpret_cast(cd.data); + 
pairs[write_pos].ptr = reinterpret_cast(cur + lb); + pairs[write_pos].length = static_cast(len); + cd.validity[write_pos] = true; + cur += lb + static_cast(len); + } +} + +/** + * Count the number of packed elements in a length-delimited blob for a given element wire type. + */ +__device__ inline int sp_count_packed( + uint8_t const* data, int data_len, int elem_wire_type) +{ + if (elem_wire_type == WT_VARINT) { + int count = 0; + uint8_t const* p = data; + uint8_t const* pe = data + data_len; + while (p < pe) { + while (p < pe && (*p & 0x80u)) p++; + if (p < pe) { p++; count++; } + } + return count; + } else if (elem_wire_type == WT_32BIT) { + return data_len / 4; + } else if (elem_wire_type == WT_64BIT) { + return data_len / 8; + } + return 0; +} + +// ============================================================================ +// Pass 1: Unified Count Kernel +// Walks each message once, counting all repeated fields at all depths. +// ============================================================================ + +__global__ void sp_unified_count_kernel( + cudf::column_device_view const d_in, + sp_msg_type const* msg_types, + sp_field_entry const* fields, + int const* d_field_lookup, + int32_t* d_counts, // [num_rows * num_count_cols] + int num_count_cols, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= d_in.size()) return; + + auto const in = cudf::detail::lists_column_device_view(d_in); + auto const& child = in.child(); + auto const base = in.offsets().element(in.offset()); + auto const start = in.offset_at(row) - base; + auto const stop = in.offset_at(row + 1) - base; + auto const* bytes = reinterpret_cast(child.data()); + + // Local counters for each counted column (repeated fields) + int32_t local_counts[SP_MAX_COUNTED]; + for (int i = 0; i < num_count_cols && i < SP_MAX_COUNTED; i++) local_counts[i] = 0; + + // Stack for nested message parsing + sp_stack_entry stack[SP_MAX_DEPTH]; + int depth = 0; + int 
msg_type = 0; // Root message type + + uint8_t const* cur = bytes + start; + uint8_t const* end_ptr = bytes + stop; + + while (cur < end_ptr || depth > 0) { + if (cur >= end_ptr) { + if (depth <= 0) break; + depth--; + cur = end_ptr; + end_ptr = bytes + start + stack[depth].parent_end_offset; + msg_type = stack[depth].msg_type_idx; + continue; + } + + // Read tag + uint64_t key; int kb; + if (!read_varint(cur, end_ptr, key, kb)) { atomicExch(error_flag, 1); break; } + cur += kb; + int fn = static_cast(key >> 3); + int wt = static_cast(key & 0x7); + + int fi = sp_lookup_field(msg_types, fields, d_field_lookup, msg_type, fn); + + if (fi < 0) { + // Unknown field - skip + uint8_t const* next; + if (!skip_field(cur, end_ptr, wt, next)) { atomicExch(error_flag, 1); break; } + cur = next; + continue; + } + + auto const& fe = fields[fi]; + + // Check for packed encoding (repeated + WT_LEN but element is not LEN) + if (fe.is_repeated && wt == WT_LEN && fe.wire_type != WT_LEN && fe.count_idx >= 0) { + uint64_t len; int lb; + if (!read_varint(cur, end_ptr, len, lb)) { atomicExch(error_flag, 1); break; } + int packed_len = static_cast(len); + local_counts[fe.count_idx] += sp_count_packed(cur + lb, packed_len, fe.wire_type); + cur += lb + packed_len; + continue; + } + + // Wire type mismatch - skip + if (wt != fe.wire_type) { + uint8_t const* next; + if (!skip_field(cur, end_ptr, wt, next)) { atomicExch(error_flag, 1); break; } + cur = next; + continue; + } + + // Nested message field + if (fe.child_msg_type >= 0 && wt == WT_LEN) { + uint64_t len; int lb; + if (!read_varint(cur, end_ptr, len, lb)) { atomicExch(error_flag, 1); break; } + cur += lb; + int sub_end = static_cast((cur + static_cast(len)) - (bytes + start)); + + if (fe.is_repeated && fe.count_idx >= 0) { + local_counts[fe.count_idx]++; + } + + if (depth < SP_MAX_DEPTH) { + stack[depth] = {static_cast(end_ptr - (bytes + start)), msg_type, 0}; + depth++; + end_ptr = bytes + start + sub_end; + msg_type = 
fe.child_msg_type; + } else { + // Max depth exceeded - skip sub-message + cur += static_cast(len); + } + continue; + } + + // Repeated non-message field + if (fe.is_repeated && fe.count_idx >= 0) { + local_counts[fe.count_idx]++; + } + + // Skip field value + uint8_t const* next; + if (!skip_field(cur, end_ptr, wt, next)) { atomicExch(error_flag, 1); break; } + cur = next; + } + + // Write counts to global memory + for (int i = 0; i < num_count_cols && i < SP_MAX_COUNTED; i++) { + d_counts[static_cast(row) * num_count_cols + i] = local_counts[i]; + } +} + +// ============================================================================ +// Pass 2: Unified Extract Kernel +// Walks each message once, extracting all field values at all depths. +// ============================================================================ + +__global__ void sp_unified_extract_kernel( + cudf::column_device_view const d_in, + sp_msg_type const* msg_types, + sp_field_entry const* fields, + int const* d_field_lookup, + sp_col_desc* col_descs, + int32_t const* d_row_offsets, // [num_rows * num_count_cols] - per-row write offsets + int32_t* const* d_parent_bufs, // [num_count_cols] - parent index buffers (null if not inner) + int num_count_cols, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= d_in.size()) return; + + auto const in = cudf::detail::lists_column_device_view(d_in); + auto const& child = in.child(); + auto const base = in.offsets().element(in.offset()); + auto const start = in.offset_at(row) - base; + auto const stop = in.offset_at(row + 1) - base; + auto const* bytes = reinterpret_cast(child.data()); + + // Local write counters (initialized from row offsets) + int32_t local_counter[SP_MAX_COUNTED]; + for (int i = 0; i < num_count_cols && i < SP_MAX_COUNTED; i++) { + local_counter[i] = d_row_offsets[static_cast(row) * num_count_cols + i]; + } + + sp_stack_entry stack[SP_MAX_DEPTH]; + int depth = 0; + int msg_type = 0; + int 
write_base = row; // Write position for non-repeated children + + uint8_t const* cur = bytes + start; + uint8_t const* end_ptr = bytes + stop; + + while (cur < end_ptr || depth > 0) { + if (cur >= end_ptr) { + if (depth <= 0) break; + depth--; + cur = end_ptr; + end_ptr = bytes + start + stack[depth].parent_end_offset; + msg_type = stack[depth].msg_type_idx; + write_base = stack[depth].write_base; + continue; + } + + // Read tag + uint64_t key; int kb; + if (!read_varint(cur, end_ptr, key, kb)) { atomicExch(error_flag, 1); break; } + cur += kb; + int fn = static_cast(key >> 3); + int wt = static_cast(key & 0x7); + + int fi = sp_lookup_field(msg_types, fields, d_field_lookup, msg_type, fn); + + if (fi < 0) { + uint8_t const* next; + if (!skip_field(cur, end_ptr, wt, next)) { atomicExch(error_flag, 1); break; } + cur = next; + continue; + } + + auto const& fe = fields[fi]; + + // Packed encoding for repeated scalars + if (fe.is_repeated && wt == WT_LEN && fe.wire_type != WT_LEN && fe.count_idx >= 0) { + uint64_t len; int lb; + if (!read_varint(cur, end_ptr, len, lb)) { atomicExch(error_flag, 1); break; } + uint8_t const* pstart = cur + lb; + uint8_t const* pend = pstart + static_cast(len); + uint8_t const* p = pstart; + + while (p < pend) { + int pos = local_counter[fe.count_idx]++; + if (d_parent_bufs && d_parent_bufs[fe.count_idx]) { + d_parent_bufs[fe.count_idx][pos] = write_base; + } + if (fe.col_idx >= 0) { + auto& cd = col_descs[fe.col_idx]; + if (fe.wire_type == WT_VARINT) { + uint64_t val; int vb; + if (!read_varint(p, pend, val, vb)) break; + p += vb; + if (fe.encoding == spark_rapids_jni::ENC_ZIGZAG) val = (val >> 1) ^ (-(val & 1)); + int tid = fe.output_type_id; + if (tid == static_cast(cudf::type_id::BOOL8)) + reinterpret_cast(cd.data)[pos] = val ? 
1 : 0; + else if (tid == static_cast(cudf::type_id::INT32)) + reinterpret_cast(cd.data)[pos] = static_cast(val); + else if (tid == static_cast(cudf::type_id::UINT32)) + reinterpret_cast(cd.data)[pos] = static_cast(val); + else if (tid == static_cast(cudf::type_id::INT64)) + reinterpret_cast(cd.data)[pos] = static_cast(val); + else if (tid == static_cast(cudf::type_id::UINT64)) + reinterpret_cast(cd.data)[pos] = val; + cd.validity[pos] = true; + } else if (fe.wire_type == WT_32BIT) { + if (pend - p < 4) break; + uint32_t raw = load_le(p); p += 4; + if (fe.output_type_id == static_cast(cudf::type_id::FLOAT32)) { + float f; memcpy(&f, &raw, 4); + reinterpret_cast(cd.data)[pos] = f; + } else { + reinterpret_cast(cd.data)[pos] = static_cast(raw); + } + cd.validity[pos] = true; + } else if (fe.wire_type == WT_64BIT) { + if (pend - p < 8) break; + uint64_t raw = load_le(p); p += 8; + if (fe.output_type_id == static_cast(cudf::type_id::FLOAT64)) { + double d; memcpy(&d, &raw, 8); + reinterpret_cast(cd.data)[pos] = d; + } else { + reinterpret_cast(cd.data)[pos] = static_cast(raw); + } + cd.validity[pos] = true; + } + } + } + cur = pend; + continue; + } + + // Wire type mismatch - skip + if (wt != fe.wire_type) { + uint8_t const* next; + if (!skip_field(cur, end_ptr, wt, next)) { atomicExch(error_flag, 1); break; } + cur = next; + continue; + } + + // Nested message + if (fe.child_msg_type >= 0 && wt == WT_LEN) { + uint64_t len; int lb; + if (!read_varint(cur, end_ptr, len, lb)) { atomicExch(error_flag, 1); break; } + cur += lb; + int sub_end = static_cast((cur + static_cast(len)) - (bytes + start)); + + int new_write_base = write_base; + if (fe.is_repeated && fe.count_idx >= 0) { + int p_pos = local_counter[fe.count_idx]++; + if (d_parent_bufs && d_parent_bufs[fe.count_idx]) { + d_parent_bufs[fe.count_idx][p_pos] = write_base; + } + new_write_base = p_pos; + } + // Set struct validity if we have a col_idx + if (fe.col_idx >= 0) { + 
col_descs[fe.col_idx].validity[new_write_base] = true; + } + + if (depth < SP_MAX_DEPTH) { + stack[depth] = {static_cast(end_ptr - (bytes + start)), msg_type, write_base}; + depth++; + end_ptr = bytes + start + sub_end; + msg_type = fe.child_msg_type; + write_base = new_write_base; + } else { + cur += static_cast(len); + } + continue; + } + + // Non-message field: extract value + if (fe.is_repeated && fe.count_idx >= 0) { + int pos = local_counter[fe.count_idx]++; + if (d_parent_bufs && d_parent_bufs[fe.count_idx]) { + d_parent_bufs[fe.count_idx][pos] = write_base; + } + sp_write_scalar(cur, end_ptr, fe, col_descs, pos); + } else { + // Non-repeated: write at write_base (last one wins on overwrite) + sp_write_scalar(cur, end_ptr, fe, col_descs, write_base); + } + } +} + +// ============================================================================ +// Fused prefix sum + list offsets kernels (replaces per-column thrust loops) +// ============================================================================ + +/** + * Compute exclusive prefix sums for ALL count columns in a single kernel launch. + * One thread per count column - each thread serially scans its column. + * Also writes the per-column totals and builds list offsets (num_rows+1). 
+ */ +__global__ void sp_compute_offsets_kernel( + int32_t const* d_counts, // [num_rows × num_count_cols] row-major + int32_t* d_row_offsets, // [num_rows × num_count_cols] row-major output + int32_t* d_totals, // [num_count_cols] output + int32_t** d_list_offs_ptrs, // [num_count_cols] pointers to list offset buffers (num_rows+1 each) + int num_rows, + int num_count_cols) +{ + int c = blockIdx.x * blockDim.x + threadIdx.x; + if (c >= num_count_cols) return; + + int32_t* list_offs = d_list_offs_ptrs[c]; + int32_t sum = 0; + for (int r = 0; r < num_rows; r++) { + auto idx = static_cast(r) * num_count_cols + c; + int32_t val = d_counts[idx]; + d_row_offsets[idx] = sum; + if (list_offs) list_offs[r] = sum; + sum += val; + } + d_totals[c] = sum; + if (list_offs) list_offs[num_rows] = sum; +} + +// ============================================================================ +// Host-side helpers for single-pass decoder +// ============================================================================ + +/// Host-side column info for assembly +struct sp_host_col_info { + int schema_idx; + int col_idx; // col_idx in sp_col_desc (-1 for repeated struct containers) + int count_idx; // For repeated fields (-1 otherwise) + int parent_count_idx; // count_idx of nearest repeated ancestor (-1 for top-level) + cudf::type_id type_id; + bool is_repeated; + bool is_string; + int parent_schema_idx; // -1 for top-level +}; + +/** + * Build single-pass schema from nested_field_descriptor arrays. + * Produces message type tables, field entries, and column info. 
+ */ +void build_single_pass_schema( + std::vector const& schema, + std::vector const& schema_output_types, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + // Outputs: + std::vector& msg_types, + std::vector& field_entries, + std::vector& col_infos, + std::vector& field_lookup_table, + int& num_count_cols, + int& num_output_cols) +{ + int num_fields = static_cast(schema.size()); + + // Group children by parent_idx + std::map> parent_to_children; + for (int i = 0; i < num_fields; i++) { + parent_to_children[schema[i].parent_idx].push_back(i); + } + + // Assign message type indices: root first, then each struct parent + std::map parent_to_msg_type; + int msg_type_counter = 0; + parent_to_msg_type[-1] = msg_type_counter++; + + for (int i = 0; i < num_fields; i++) { + auto type_id = schema_output_types[i].id(); + if (type_id == cudf::type_id::STRUCT && parent_to_children.count(i) > 0) { + parent_to_msg_type[i] = msg_type_counter++; + } + } + + // Assign col_idx and count_idx via DFS + int col_counter = 0; + int count_counter = 0; + std::map schema_to_col_idx; + std::map schema_to_count_idx; + + std::function assign_indices = [&](int parent_idx, int parent_count_idx) { + auto it = parent_to_children.find(parent_idx); + if (it == parent_to_children.end()) return; + + for (int si : it->second) { + auto type_id = schema_output_types[si].id(); + bool is_repeated = schema[si].is_repeated; + bool is_struct = (type_id == cudf::type_id::STRUCT); + // STRING and LIST (bytes) are both length-delimited and stored as sp_string_pair + bool is_string = (type_id == cudf::type_id::STRING || type_id == cudf::type_id::LIST); + + int my_count_idx = -1; + if (is_repeated) { + my_count_idx = count_counter++; + schema_to_count_idx[si] = my_count_idx; + } + + int my_col_idx = -1; + // All non-repeated-struct fields get a col_idx for data writing. + // Non-repeated struct containers also get one for validity tracking. 
+ if (is_struct && !is_repeated) { + my_col_idx = col_counter++; + schema_to_col_idx[si] = my_col_idx; + } else if (!is_struct) { + my_col_idx = col_counter++; + schema_to_col_idx[si] = my_col_idx; + } + // Repeated structs: no col_idx (list offsets from count, struct from children) + + sp_host_col_info info{}; + info.schema_idx = si; + info.col_idx = my_col_idx; + info.count_idx = my_count_idx; + info.parent_count_idx = parent_count_idx; + info.type_id = type_id; + info.is_repeated = is_repeated; + info.is_string = is_string; + info.parent_schema_idx = parent_idx; + col_infos.push_back(info); + + if (is_struct) { + int child_parent_count = is_repeated ? my_count_idx : parent_count_idx; + assign_indices(si, child_parent_count); + } + } + }; + + assign_indices(-1, -1); + num_count_cols = count_counter; + num_output_cols = col_counter; + + // Build sp_msg_type and sp_field_entry arrays + msg_types.resize(msg_type_counter); + for (auto& [pidx, mt_idx] : parent_to_msg_type) { + auto it = parent_to_children.find(pidx); + if (it == parent_to_children.end()) { + msg_types[mt_idx] = {static_cast(field_entries.size()), 0, -1, 0}; + continue; + } + auto children = it->second; + std::sort(children.begin(), children.end(), [&](int a, int b) { + return schema[a].field_number < schema[b].field_number; + }); + + int first_idx = static_cast(field_entries.size()); + for (int si : children) { + sp_field_entry e{}; + e.field_number = schema[si].field_number; + e.wire_type = schema[si].wire_type; + e.output_type_id = static_cast(schema_output_types[si].id()); + e.encoding = schema[si].encoding; + e.is_repeated = schema[si].is_repeated; + e.has_default = schema[si].has_default_value; + e.default_int = e.has_default ? default_ints[si] : 0; + e.default_float = e.has_default ? 
default_floats[si] : 0.0; + + auto type_id = schema_output_types[si].id(); + if (type_id == cudf::type_id::STRUCT) { + auto mt_it = parent_to_msg_type.find(si); + e.child_msg_type = (mt_it != parent_to_msg_type.end()) ? mt_it->second : -1; + } else { + e.child_msg_type = -1; + } + + auto col_it = schema_to_col_idx.find(si); + e.col_idx = (col_it != schema_to_col_idx.end()) ? col_it->second : -1; + auto cnt_it = schema_to_count_idx.find(si); + e.count_idx = (cnt_it != schema_to_count_idx.end()) ? cnt_it->second : -1; + + field_entries.push_back(e); + } + msg_types[mt_idx] = {first_idx, static_cast(children.size()), -1, 0}; + } + + // Build direct-mapped field lookup table for O(1) field lookup + // For each message type, allocate a region of [0..max_field_number) in the table. + // table[offset + field_number] = index into field_entries, or -1 if not found. + int lookup_offset = 0; + for (int mt = 0; mt < msg_type_counter; mt++) { + auto& mtype = msg_types[mt]; + if (mtype.num_fields == 0) { + mtype.lookup_offset = lookup_offset; + mtype.max_field_number = 1; // at least 1 to avoid zero-size + field_lookup_table.push_back(-1); + lookup_offset += 1; + continue; + } + // Find max field number in this message type + int max_fn = 0; + for (int f = mtype.first_field_idx; f < mtype.first_field_idx + mtype.num_fields; f++) { + max_fn = std::max(max_fn, field_entries[f].field_number); + } + int table_size = max_fn + 1; + mtype.lookup_offset = lookup_offset; + mtype.max_field_number = table_size; + + // Fill with -1 (not found) + int base = static_cast(field_lookup_table.size()); + field_lookup_table.resize(base + table_size, -1); + // Set entries for known fields + for (int f = mtype.first_field_idx; f < mtype.first_field_idx + mtype.num_fields; f++) { + field_lookup_table[base + field_entries[f].field_number] = f; + } + lookup_offset += table_size; + } +} + +/** + * Recursively build a cudf column for a field in the schema. + * Returns the assembled column. 
+ */ +std::unique_ptr sp_build_column_recursive( + std::vector const& schema, + std::vector const& schema_output_types, + std::vector const& col_infos, + std::map const& schema_idx_to_info, + // Buffers (bulk-allocated): + std::vector& col_data_ptrs, // col_idx -> data pointer in bulk buffer + std::vector& col_validity_ptrs, // col_idx -> validity pointer in bulk buffer + std::vector>& list_offsets_bufs, // count_idx -> offsets (top-level) + std::vector& inner_offs_ptrs, // count_idx -> inner offsets pointer (or null) + std::vector& inner_buf_sizes, // count_idx -> inner offsets size (0 if not inner) + std::vector& col_sizes, // col_idx -> element count + std::vector& count_totals, // count_idx -> total count + std::vector& col_elem_bytes, // col_idx -> element byte size + int schema_idx, + int num_fields, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto it = schema_idx_to_info.find(schema_idx); + if (it == schema_idx_to_info.end()) { + return make_null_column(schema_output_types[schema_idx], num_rows, stream, mr); + } + auto const& info = *(it->second); + auto type_id = info.type_id; + bool is_repeated = info.is_repeated; + bool is_string = info.is_string; + + // Determine element count for this column + int elem_count = num_rows; + if (info.parent_count_idx >= 0) { + elem_count = count_totals[info.parent_count_idx]; + } + + if (type_id == cudf::type_id::STRUCT) { + // Find children of this struct + std::vector child_schema_indices; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == schema_idx) child_schema_indices.push_back(i); + } + + if (is_repeated) { + // LIST: build struct children, then wrap in list + int total = count_totals[info.count_idx]; + std::vector> struct_children; + for (int child_si : child_schema_indices) { + struct_children.push_back(sp_build_column_recursive( + schema, schema_output_types, col_infos, schema_idx_to_info, + col_data_ptrs, col_validity_ptrs, list_offsets_bufs, 
inner_offs_ptrs, inner_buf_sizes, + col_sizes, count_totals, col_elem_bytes, child_si, num_fields, num_rows, stream, mr)); + } + auto struct_col = cudf::make_structs_column( + total, std::move(struct_children), 0, rmm::device_buffer{}, stream, mr); + + // List offsets: use inner offsets for nested repeated fields, top-level offsets otherwise + std::unique_ptr offsets_col; + if (inner_offs_ptrs[info.count_idx] != nullptr) { + int sz = inner_buf_sizes[info.count_idx]; + auto buf = rmm::device_buffer(sz * sizeof(int32_t), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(buf.data(), inner_offs_ptrs[info.count_idx], + sz * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream.value())); + offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, sz, std::move(buf), rmm::device_buffer{}, 0); + } else { + auto& offs = list_offsets_bufs[info.count_idx]; + auto const offs_size = static_cast(offs.size()); + offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, offs_size, offs.release(), rmm::device_buffer{}, 0); + } + + return cudf::make_lists_column( + elem_count, std::move(offsets_col), std::move(struct_col), + 0, rmm::device_buffer{}, stream, mr); + } else { + // Non-repeated struct + std::vector> struct_children; + for (int child_si : child_schema_indices) { + struct_children.push_back(sp_build_column_recursive( + schema, schema_output_types, col_infos, schema_idx_to_info, + col_data_ptrs, col_validity_ptrs, list_offsets_bufs, inner_offs_ptrs, inner_buf_sizes, + col_sizes, count_totals, col_elem_bytes, child_si, num_fields, num_rows, stream, mr)); + } + // Struct validity from col_idx + int ci = info.col_idx; + if (ci >= 0 && col_validity_ptrs[ci] != nullptr && col_sizes[ci] > 0) { + auto [mask, null_count] = cudf::detail::valid_if( + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(elem_count), + [vld = col_validity_ptrs[ci]] __device__ (cudf::size_type i) { return vld[i]; }, + stream, mr); + return cudf::make_structs_column( 
+ elem_count, std::move(struct_children), null_count, std::move(mask), stream, mr); + } + return cudf::make_structs_column( + elem_count, std::move(struct_children), 0, rmm::device_buffer{}, stream, mr); + } + } + + // Leaf field (scalar or string) + int ci = info.col_idx; + if (ci < 0) { + return make_null_column(schema_output_types[schema_idx], elem_count, stream, mr); + } + + // Helper lambda: build a STRING column from sp_string_pair data + auto build_string_col = [&](int col_idx, int count, bool use_validity) -> std::unique_ptr { + auto* pairs = reinterpret_cast(col_data_ptrs[col_idx]); + rmm::device_uvector str_pairs(count, stream, mr); + if (use_validity) { + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), thrust::make_counting_iterator(count), + str_pairs.begin(), + [pairs, vld = col_validity_ptrs[col_idx]] __device__ (int i) -> cudf::strings::detail::string_index_pair { + if (vld[i]) return {pairs[i].ptr, pairs[i].length}; + return {nullptr, 0}; + }); + } else { + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), thrust::make_counting_iterator(count), + str_pairs.begin(), + [pairs] __device__ (int i) -> cudf::strings::detail::string_index_pair { + return {pairs[i].ptr, pairs[i].length}; + }); + } + return cudf::strings::detail::make_strings_column( + str_pairs.begin(), str_pairs.end(), stream, mr); + }; + + // Helper lambda: build a LIST (bytes/binary) column from sp_string_pair data + auto build_bytes_col = [&](int col_idx, int count) -> std::unique_ptr { + auto* pairs = reinterpret_cast(col_data_ptrs[col_idx]); + auto* vld = col_validity_ptrs[col_idx]; + // Compute lengths and prefix sum -> offsets (inclusive scan then shift) + rmm::device_uvector byte_offs(count + 1, stream, mr); + if (count > 0) { + // Compute lengths directly into offsets[1..count], then exclusive_scan + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), thrust::make_counting_iterator(count), 
+ byte_offs.begin(), // write to [0..count-1] + [pairs, vld] __device__ (int i) -> int32_t { return vld[i] ? pairs[i].length : 0; }); + thrust::exclusive_scan(rmm::exec_policy(stream), + byte_offs.begin(), byte_offs.begin() + count, byte_offs.begin(), 0); + // Total bytes via transform_reduce (avoids D->H sync for last_off + last_len) + int32_t total_bytes = thrust::transform_reduce(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), thrust::make_counting_iterator(count), + [pairs, vld] __device__ (int i) -> int32_t { + return vld[i] ? pairs[i].length : 0; + }, 0, cuda::std::plus{}); + CUDF_CUDA_TRY(cudaMemcpyAsync(byte_offs.data() + count, &total_bytes, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + // Copy binary data + rmm::device_uvector child_data(total_bytes > 0 ? total_bytes : 0, stream, mr); + if (total_bytes > 0) { + thrust::for_each(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), thrust::make_counting_iterator(count), + [pairs, offs = byte_offs.data(), out = child_data.data(), vld] __device__ (int i) { + if (vld[i] && pairs[i].ptr && pairs[i].length > 0) { + memcpy(out + offs[i], pairs[i].ptr, pairs[i].length); + } + }); + } + auto off_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, count + 1, byte_offs.release(), rmm::device_buffer{}, 0); + auto ch_col = std::make_unique(cudf::data_type{cudf::type_id::UINT8}, total_bytes, child_data.release(), rmm::device_buffer{}, 0); + auto [mask, null_count] = cudf::detail::valid_if( + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(count), + [v = vld] __device__ (cudf::size_type i) { return v[i]; }, stream, mr); + return cudf::make_lists_column(count, std::move(off_col), std::move(ch_col), null_count, std::move(mask), stream, mr); + } else { + // Empty bytes column + thrust::fill(rmm::exec_policy(stream), byte_offs.begin(), byte_offs.end(), 0); + auto off_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, count + 1, 
byte_offs.release(), rmm::device_buffer{}, 0); + auto ch_col = std::make_unique(cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); + return cudf::make_lists_column(count, std::move(off_col), std::move(ch_col), 0, rmm::device_buffer{}, stream, mr); + } + }; + + bool is_bytes = (type_id == cudf::type_id::LIST); + + if (is_repeated) { + // LIST: build child column then wrap in list + int total = count_totals[info.count_idx]; + std::unique_ptr child_col; + + if (is_bytes) { + // repeated bytes -> LIST>: build inner LIST then wrap in outer list + child_col = build_bytes_col(ci, total); + } else if (is_string) { + child_col = build_string_col(ci, total, false); + } else { + auto dt = schema_output_types[schema_idx]; + auto [mask, null_count] = cudf::detail::valid_if( + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total), + [v = col_validity_ptrs[ci]] __device__ (cudf::size_type i) { return v[i]; }, stream, mr); + // Copy data from bulk buffer into a new device_buffer for cudf::column ownership + auto data_buf = rmm::device_buffer(total * col_elem_bytes[ci], stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(data_buf.data(), col_data_ptrs[ci], + total * col_elem_bytes[ci], cudaMemcpyDeviceToDevice, stream.value())); + child_col = std::make_unique( + dt, total, std::move(data_buf), std::move(mask), null_count); + } + + // Use inner offsets if available, else use top-level list offsets + std::unique_ptr offsets_col; + if (inner_offs_ptrs[info.count_idx] != nullptr) { + int sz = inner_buf_sizes[info.count_idx]; + auto buf = rmm::device_buffer(sz * sizeof(int32_t), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(buf.data(), inner_offs_ptrs[info.count_idx], + sz * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream.value())); + offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, sz, std::move(buf), rmm::device_buffer{}, 0); + } else { + auto& offs = list_offsets_bufs[info.count_idx]; + auto const offs_size 
= static_cast(offs.size()); + offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, offs_size, offs.release(), rmm::device_buffer{}, 0); + } + + return cudf::make_lists_column( + elem_count, std::move(offsets_col), std::move(child_col), + 0, rmm::device_buffer{}, stream, mr); + } + + // Non-repeated leaf + if (is_bytes) { + return build_bytes_col(ci, elem_count); + } + if (is_string) { + return build_string_col(ci, elem_count, true); + } + + // Non-repeated non-string scalar + auto dt = schema_output_types[schema_idx]; + auto [mask, null_count] = cudf::detail::valid_if( + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(elem_count), + [v = col_validity_ptrs[ci]] __device__ (cudf::size_type i) { return v[i]; }, stream, mr); + // Copy data from bulk buffer into a new device_buffer for cudf::column ownership + auto data_buf = rmm::device_buffer(elem_count * col_elem_bytes[ci], stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(data_buf.data(), col_data_ptrs[ci], + elem_count * col_elem_bytes[ci], cudaMemcpyDeviceToDevice, stream.value())); + return std::make_unique( + dt, elem_count, std::move(data_buf), std::move(mask), null_count); +} + +/** + * Main single-pass decoder orchestration. 
+ */ +std::unique_ptr decode_nested_protobuf_single_pass( + cudf::column_view const& binary_input, + std::vector const& schema, + std::vector const& schema_output_types, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + bool fail_on_errors) +{ + auto const stream = cudf::get_default_stream(); + auto mr = cudf::get_current_device_resource_ref(); + auto num_rows = binary_input.size(); + auto num_fields = static_cast(schema.size()); + + // Timing instrumentation (enabled by PROTOBUF_SP_TIMING=1) + static bool sp_timing_enabled = (std::getenv("PROTOBUF_SP_TIMING") != nullptr && + std::string(std::getenv("PROTOBUF_SP_TIMING")) == "1"); + static int sp_call_count = 0; + static double sp_phase_totals[8] = {}; // accumulate across calls + cudaEvent_t t_start, t1, t2, t3, t4, t5, t6, t7; + if (sp_timing_enabled) { + cudaEventCreate(&t_start); cudaEventCreate(&t1); cudaEventCreate(&t2); + cudaEventCreate(&t3); cudaEventCreate(&t4); cudaEventCreate(&t5); + cudaEventCreate(&t6); cudaEventCreate(&t7); + cudaEventRecord(t_start, stream.value()); + } + + // === Phase 1: Schema Preprocessing === + std::vector h_msg_types; + std::vector h_field_entries; + std::vector col_infos; + std::vector h_field_lookup; + int num_count_cols = 0; + int num_output_cols = 0; + + build_single_pass_schema(schema, schema_output_types, + default_ints, default_floats, default_bools, + h_msg_types, h_field_entries, col_infos, h_field_lookup, + num_count_cols, num_output_cols); + + // Check limits + if (num_count_cols > SP_MAX_COUNTED || num_output_cols > SP_MAX_OUTPUT_COLS) { + return nullptr; // Signal caller to fall back to old decoder + } + + // Copy schema to device + rmm::device_uvector d_msg_types(h_msg_types.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_msg_types.data(), h_msg_types.data(), + h_msg_types.size() * sizeof(sp_msg_type), cudaMemcpyHostToDevice, stream.value())); + + 
rmm::device_uvector d_field_entries(h_field_entries.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_entries.data(), h_field_entries.data(), + h_field_entries.size() * sizeof(sp_field_entry), cudaMemcpyHostToDevice, stream.value())); + + // Copy O(1) field lookup table to device + rmm::device_uvector d_field_lookup(h_field_lookup.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_lookup.data(), h_field_lookup.data(), + h_field_lookup.size() * sizeof(int), cudaMemcpyHostToDevice, stream.value())); + + auto d_in = cudf::column_device_view::create(binary_input, stream); + + rmm::device_uvector d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + + int const threads = 256; + int const blocks = (num_rows + threads - 1) / threads; + + if (sp_timing_enabled) cudaEventRecord(t1, stream.value()); // end schema prep + alloc + + // === Phase 2: Pass 1 - Count === + rmm::device_uvector d_counts( + num_count_cols > 0 ? static_cast(num_rows) * num_count_cols : 1, stream, mr); + if (num_count_cols > 0) { + CUDF_CUDA_TRY(cudaMemsetAsync(d_counts.data(), 0, + static_cast(num_rows) * num_count_cols * sizeof(int32_t), stream.value())); + } + + sp_unified_count_kernel<<>>( + *d_in, d_msg_types.data(), d_field_entries.data(), d_field_lookup.data(), + d_counts.data(), num_count_cols, d_error.data()); + + if (sp_timing_enabled) cudaEventRecord(t2, stream.value()); // end count kernel + + // === Phase 3: Compute Offsets and Allocate Buffers === + // Fused: compute all prefix sums + list offsets in a SINGLE kernel launch. + // Replaces ~50 syncs + ~200 kernel launches with 1 kernel + 1 sync. + rmm::device_uvector d_row_offsets( + num_count_cols > 0 ? 
static_cast(num_rows) * num_count_cols : 1, stream, mr); + + std::vector count_totals(num_count_cols, 0); + std::vector> list_offsets_bufs; + list_offsets_bufs.reserve(num_count_cols); + + // Pre-allocate all list offset buffers and collect device pointers + std::vector h_list_offs_ptrs(num_count_cols, nullptr); + for (int c = 0; c < num_count_cols; c++) { + list_offsets_bufs.emplace_back(num_rows + 1, stream, mr); + h_list_offs_ptrs[c] = list_offsets_bufs.back().data(); + } + + if (num_count_cols > 0) { + // Copy list offset pointers to device + rmm::device_uvector d_list_offs_ptrs(num_count_cols, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_list_offs_ptrs.data(), h_list_offs_ptrs.data(), + num_count_cols * sizeof(int32_t*), cudaMemcpyHostToDevice, stream.value())); + + // Device buffer for totals + rmm::device_uvector d_totals(num_count_cols, stream, mr); + + // Single fused kernel: prefix sums + totals + list offsets for all columns + int const off_threads = std::min(num_count_cols, 256); + int const off_blocks = (num_count_cols + off_threads - 1) / off_threads; + sp_compute_offsets_kernel<<>>( + d_counts.data(), d_row_offsets.data(), d_totals.data(), + d_list_offs_ptrs.data(), num_rows, num_count_cols); + + // Single D->H copy for all totals + single sync + CUDF_CUDA_TRY(cudaMemcpyAsync(count_totals.data(), d_totals.data(), + num_count_cols * sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + } + + // Build schema_idx -> col_info lookup + std::map schema_idx_to_info; + for (auto const& ci : col_infos) { + schema_idx_to_info[ci.schema_idx] = &ci; + } + + // Determine buffer sizes for each col_idx + std::vector col_sizes(num_output_cols, 0); + for (auto const& ci : col_infos) { + if (ci.col_idx < 0) continue; + if (ci.is_repeated && ci.count_idx >= 0) { + col_sizes[ci.col_idx] = count_totals[ci.count_idx]; + } else if (ci.parent_count_idx >= 0) { + col_sizes[ci.col_idx] = count_totals[ci.parent_count_idx]; + } else { + 
col_sizes[ci.col_idx] = num_rows; + } + } + + // Build col_idx -> col_info lookup (avoids O(N^2) inner loop) + std::vector col_idx_to_info(num_output_cols, nullptr); + for (auto const& c : col_infos) { + if (c.col_idx >= 0) col_idx_to_info[c.col_idx] = &c; + } + + // Compute per-column element sizes and total buffer sizes for BULK allocation. + // Replaces ~992 individual RMM allocations + ~992 memsets with 2 allocs + 2 memsets. + std::vector col_elem_bytes(num_output_cols, 0); + std::vector col_data_offsets(num_output_cols, 0); + std::vector col_validity_offsets(num_output_cols, 0); + size_t total_data_bytes = 0; + size_t total_validity_elems = 0; + + for (int ci_idx = 0; ci_idx < num_output_cols; ci_idx++) { + auto const* cinfo = col_idx_to_info[ci_idx]; + size_t eb = 0; + if (cinfo) { + auto tid = cinfo->type_id; + if (tid == cudf::type_id::STRING || tid == cudf::type_id::LIST) { + eb = sizeof(sp_string_pair); + } else if (tid == cudf::type_id::STRUCT) { + eb = 0; + } else { + eb = cudf::size_of(cudf::data_type{tid}); + } + } + col_elem_bytes[ci_idx] = eb; + int32_t sz = col_sizes[ci_idx]; + // Align data offset to 16 bytes for coalesced GPU access + col_data_offsets[ci_idx] = total_data_bytes; + total_data_bytes += (sz > 0 ? sz * eb : 0); + total_data_bytes = (total_data_bytes + 15) & ~size_t{15}; // 16-byte align + + col_validity_offsets[ci_idx] = total_validity_elems; + total_validity_elems += (sz > 0 ? sz : 0); + } + + // TWO bulk allocations instead of ~992 individual ones + rmm::device_uvector bulk_data(total_data_bytes > 0 ? total_data_bytes : 1, stream, mr); + rmm::device_uvector bulk_validity(total_validity_elems > 0 ? 
total_validity_elems : 1, stream, mr); + + // TWO bulk memsets instead of ~992 individual ones + if (total_data_bytes > 0) { + CUDF_CUDA_TRY(cudaMemsetAsync(bulk_data.data(), 0, total_data_bytes, stream.value())); + } + if (total_validity_elems > 0) { + CUDF_CUDA_TRY(cudaMemsetAsync(bulk_validity.data(), 0, total_validity_elems * sizeof(bool), stream.value())); + } + + // Per-column pointers into the bulk buffers + std::vector col_data_ptrs(num_output_cols, nullptr); + std::vector col_validity_ptrs(num_output_cols, nullptr); + + for (int ci_idx = 0; ci_idx < num_output_cols; ci_idx++) { + int32_t sz = col_sizes[ci_idx]; + if (sz > 0) { + col_data_ptrs[ci_idx] = bulk_data.data() + col_data_offsets[ci_idx]; + col_validity_ptrs[ci_idx] = bulk_validity.data() + col_validity_offsets[ci_idx]; + } + } + + // Fill non-zero defaults (rare - proto3 defaults are all 0) + for (int ci_idx = 0; ci_idx < num_output_cols; ci_idx++) { + auto const* cinfo = col_idx_to_info[ci_idx]; + int32_t sz = col_sizes[ci_idx]; + if (!cinfo || sz <= 0) continue; + if (cinfo->type_id == cudf::type_id::STRING || + cinfo->type_id == cudf::type_id::LIST || + cinfo->type_id == cudf::type_id::STRUCT) continue; + int si = cinfo->schema_idx; + if (!schema[si].has_default_value) continue; + + auto tid = cinfo->type_id; + bool non_zero = false; + if (tid == cudf::type_id::BOOL8) non_zero = default_bools[si]; + else if (tid == cudf::type_id::FLOAT32 || tid == cudf::type_id::FLOAT64) + non_zero = (default_floats[si] != 0.0); + else non_zero = (default_ints[si] != 0); + + if (non_zero) { + thrust::fill_n(rmm::exec_policy(stream), col_validity_ptrs[ci_idx], sz, true); + auto* dp = col_data_ptrs[ci_idx]; + if (tid == cudf::type_id::BOOL8) + thrust::fill_n(rmm::exec_policy(stream), dp, sz, static_cast(1)); + else if (tid == cudf::type_id::INT32 || tid == cudf::type_id::UINT32) + thrust::fill_n(rmm::exec_policy(stream), reinterpret_cast(dp), sz, static_cast(default_ints[si])); + else if (tid == 
cudf::type_id::INT64 || tid == cudf::type_id::UINT64) + thrust::fill_n(rmm::exec_policy(stream), reinterpret_cast(dp), sz, default_ints[si]); + else if (tid == cudf::type_id::FLOAT32) + thrust::fill_n(rmm::exec_policy(stream), reinterpret_cast(dp), sz, static_cast(default_floats[si])); + else if (tid == cudf::type_id::FLOAT64) + thrust::fill_n(rmm::exec_policy(stream), reinterpret_cast(dp), sz, default_floats[si]); + } + } + + // Build device-side column descriptors (using bulk buffer pointers) + std::vector h_col_descs(num_output_cols); + for (int i = 0; i < num_output_cols; i++) { + h_col_descs[i].data = col_data_ptrs[i]; + h_col_descs[i].validity = col_validity_ptrs[i]; + } + rmm::device_uvector d_col_descs(num_output_cols, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_col_descs.data(), h_col_descs.data(), + num_output_cols * sizeof(sp_col_desc), cudaMemcpyHostToDevice, stream.value())); + + // Allocate parent index buffers for inner repeated fields + std::vector> parent_idx_storage; + std::vector h_parent_bufs(num_count_cols, nullptr); + parent_idx_storage.reserve(num_count_cols); + + for (auto const& ci : col_infos) { + if (ci.count_idx >= 0 && ci.parent_count_idx >= 0) { + // Inner repeated field: needs parent index buffer + int total = count_totals[ci.count_idx]; + parent_idx_storage.emplace_back(total > 0 ? 
total : 0, stream, mr); + h_parent_bufs[ci.count_idx] = parent_idx_storage.back().data(); + } + } + + rmm::device_uvector d_parent_bufs_arr(num_count_cols, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_parent_bufs_arr.data(), h_parent_bufs.data(), + num_count_cols * sizeof(int32_t*), cudaMemcpyHostToDevice, stream.value())); + + if (sp_timing_enabled) cudaEventRecord(t3, stream.value()); // end offsets + buffer alloc + + // === Phase 4: Pass 2 - Extract === + sp_unified_extract_kernel<<>>( + *d_in, d_msg_types.data(), d_field_entries.data(), d_field_lookup.data(), + d_col_descs.data(), d_row_offsets.data(), + d_parent_bufs_arr.data(), num_count_cols, d_error.data()); + + if (sp_timing_enabled) cudaEventRecord(t4, stream.value()); // end extract kernel + + // === Phase 5: Compute Inner List Offsets === + // Pre-compute which count columns are inner (parent_count_idx >= 0) and their sizes. + // Bulk-allocate a single buffer for all inner offsets to avoid memory pool fragmentation. + struct inner_info_t { int count_idx; int parent_count_idx; int total_child; int total_parent; }; + std::vector inner_infos; + size_t total_inner_elems = 0; + std::vector inner_buf_offsets(num_count_cols, -1); // offset into bulk inner buffer + std::vector inner_buf_sizes(num_count_cols, 0); + + for (int c = 0; c < num_count_cols; c++) { + sp_host_col_info const* cinfo_ptr = nullptr; + for (auto const& ci : col_infos) { + if (ci.count_idx == c) { cinfo_ptr = &ci; break; } + } + if (cinfo_ptr && cinfo_ptr->parent_count_idx >= 0) { + int total_child = count_totals[c]; + int total_parent = count_totals[cinfo_ptr->parent_count_idx]; + int sz = total_parent + 1; + inner_buf_offsets[c] = static_cast(total_inner_elems); + inner_buf_sizes[c] = sz; + total_inner_elems += sz; + inner_infos.push_back({c, cinfo_ptr->parent_count_idx, total_child, total_parent}); + } + } + + // Single bulk allocation for all inner offsets + rmm::device_uvector bulk_inner_offsets( + total_inner_elems > 0 ? 
total_inner_elems : 1, stream, mr); + if (total_inner_elems > 0) { + CUDF_CUDA_TRY(cudaMemsetAsync(bulk_inner_offsets.data(), 0, + total_inner_elems * sizeof(int32_t), stream.value())); + } + + // Compute inner offsets via lower_bound (only for non-empty inner fields) + for (auto const& ii : inner_infos) { + int c = ii.count_idx; + int32_t* out = bulk_inner_offsets.data() + inner_buf_offsets[c]; + if (ii.total_child > 0 && ii.total_parent > 0 && h_parent_bufs[c] != nullptr) { + thrust::lower_bound(rmm::exec_policy(stream), + h_parent_bufs[c], h_parent_bufs[c] + ii.total_child, + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(ii.total_parent + 1), + out); + } + // else: already zeroed by memset + } + + // Build inner offset pointer array (points into bulk buffer, no per-column allocation) + std::vector inner_offs_ptrs(num_count_cols, nullptr); + for (int c = 0; c < num_count_cols; c++) { + if (inner_buf_offsets[c] >= 0) { + inner_offs_ptrs[c] = bulk_inner_offsets.data() + inner_buf_offsets[c]; + } + } + + if (sp_timing_enabled) cudaEventRecord(t5, stream.value()); // end inner offsets + + // === Phase 6: Column Assembly === + // Check for errors + CUDF_CUDA_TRY(cudaPeekAtLastError()); + int h_error = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + if (fail_on_errors) { + CUDF_EXPECTS(h_error == 0, "Malformed protobuf message or unsupported wire type"); + } + + // Build top-level struct column + std::vector> top_children; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == -1) { + top_children.push_back(sp_build_column_recursive( + schema, schema_output_types, col_infos, schema_idx_to_info, + col_data_ptrs, col_validity_ptrs, list_offsets_bufs, inner_offs_ptrs, inner_buf_sizes, + col_sizes, count_totals, col_elem_bytes, i, num_fields, num_rows, stream, mr)); + } + } + + auto result = cudf::make_structs_column( + num_rows, 
std::move(top_children), 0, rmm::device_buffer{}, stream, mr); + + // Print timing results + if (sp_timing_enabled) { + cudaEventRecord(t6, stream.value()); // end column assembly + cudaEventSynchronize(t6); + + float ms[7]; + cudaEventElapsedTime(&ms[0], t_start, t1); // schema prep + device copy + cudaEventElapsedTime(&ms[1], t1, t2); // count kernel + cudaEventElapsedTime(&ms[2], t2, t3); // offsets + buffer alloc + cudaEventElapsedTime(&ms[3], t3, t4); // extract kernel + cudaEventElapsedTime(&ms[4], t4, t5); // inner offsets + cudaEventElapsedTime(&ms[5], t5, t6); // column assembly + cudaEventElapsedTime(&ms[6], t_start, t6); // total + + sp_call_count++; + sp_phase_totals[0] += ms[0]; sp_phase_totals[1] += ms[1]; + sp_phase_totals[2] += ms[2]; sp_phase_totals[3] += ms[3]; + sp_phase_totals[4] += ms[4]; sp_phase_totals[5] += ms[5]; + sp_phase_totals[6] += ms[6]; + + if (sp_call_count % 50 == 0) { + fprintf(stderr, + "[SP-TIMING] call#%d rows=%d fields=%d count_cols=%d out_cols=%d | " + "THIS: prep=%.1f count=%.1f offsets=%.1f extract=%.1f inner=%.1f assembly=%.1f TOTAL=%.1f ms | " + "CUMUL: prep=%.0f count=%.0f offsets=%.0f extract=%.0f inner=%.0f assembly=%.0f TOTAL=%.0f ms\n", + sp_call_count, num_rows, num_fields, num_count_cols, num_output_cols, + ms[0], ms[1], ms[2], ms[3], ms[4], ms[5], ms[6], + sp_phase_totals[0], sp_phase_totals[1], sp_phase_totals[2], + sp_phase_totals[3], sp_phase_totals[4], sp_phase_totals[5], + sp_phase_totals[6]); + } + + cudaEventDestroy(t_start); cudaEventDestroy(t1); cudaEventDestroy(t2); + cudaEventDestroy(t3); cudaEventDestroy(t4); cudaEventDestroy(t5); + cudaEventDestroy(t6); cudaEventDestroy(t7); + } + + return result; +} + } // anonymous namespace std::unique_ptr decode_nested_protobuf_to_struct( @@ -3644,6 +4999,26 @@ std::unique_ptr decode_nested_protobuf_to_struct( return cudf::make_structs_column(0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); } + // Try single-pass decoder (faster for complex 
nested schemas) + // Can be disabled by setting PROTOBUF_NO_SINGLE_PASS=1 + { + char const* no_sp = std::getenv("PROTOBUF_NO_SINGLE_PASS"); + bool use_single_pass = !(no_sp && std::string(no_sp) == "1"); + if (use_single_pass) { + auto result = decode_nested_protobuf_single_pass( + binary_input, schema, schema_output_types, + default_ints, default_floats, default_bools, default_strings, + fail_on_errors); + CUDF_EXPECTS(result != nullptr, + "Single-pass protobuf decoder failed: schema exceeds limits " + "(SP_MAX_COUNTED=" + std::to_string(SP_MAX_COUNTED) + + " or SP_MAX_OUTPUT_COLS=" + std::to_string(SP_MAX_OUTPUT_COLS) + + ", actual num_count_cols or num_output_cols too large). " + "Set PROTOBUF_NO_SINGLE_PASS=1 to use the old decoder."); + return result; + } + } + // Copy schema to device std::vector h_device_schema(num_fields); for (int i = 0; i < num_fields; i++) { From c52782c2ca29a99ef6b8b56b55c5dcf0c71137fd Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 9 Feb 2026 10:36:52 +0800 Subject: [PATCH 018/107] delete debug log Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 64 +----------------------------------- 1 file changed, 1 insertion(+), 63 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index a762ffc1cd..c6abb11682 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -4562,19 +4562,6 @@ std::unique_ptr decode_nested_protobuf_single_pass( auto num_rows = binary_input.size(); auto num_fields = static_cast(schema.size()); - // Timing instrumentation (enabled by PROTOBUF_SP_TIMING=1) - static bool sp_timing_enabled = (std::getenv("PROTOBUF_SP_TIMING") != nullptr && - std::string(std::getenv("PROTOBUF_SP_TIMING")) == "1"); - static int sp_call_count = 0; - static double sp_phase_totals[8] = {}; // accumulate across calls - cudaEvent_t t_start, t1, t2, t3, t4, t5, t6, t7; - if (sp_timing_enabled) { - cudaEventCreate(&t_start); cudaEventCreate(&t1); cudaEventCreate(&t2); - 
cudaEventCreate(&t3); cudaEventCreate(&t4); cudaEventCreate(&t5); - cudaEventCreate(&t6); cudaEventCreate(&t7); - cudaEventRecord(t_start, stream.value()); - } - // === Phase 1: Schema Preprocessing === std::vector h_msg_types; std::vector h_field_entries; @@ -4615,8 +4602,6 @@ std::unique_ptr decode_nested_protobuf_single_pass( int const threads = 256; int const blocks = (num_rows + threads - 1) / threads; - if (sp_timing_enabled) cudaEventRecord(t1, stream.value()); // end schema prep + alloc - // === Phase 2: Pass 1 - Count === rmm::device_uvector d_counts( num_count_cols > 0 ? static_cast(num_rows) * num_count_cols : 1, stream, mr); @@ -4629,8 +4614,6 @@ std::unique_ptr decode_nested_protobuf_single_pass( *d_in, d_msg_types.data(), d_field_entries.data(), d_field_lookup.data(), d_counts.data(), num_count_cols, d_error.data()); - if (sp_timing_enabled) cudaEventRecord(t2, stream.value()); // end count kernel - // === Phase 3: Compute Offsets and Allocate Buffers === // Fused: compute all prefix sums + list offsets in a SINGLE kernel launch. // Replaces ~50 syncs + ~200 kernel launches with 1 kernel + 1 sync. @@ -4813,16 +4796,12 @@ std::unique_ptr decode_nested_protobuf_single_pass( CUDF_CUDA_TRY(cudaMemcpyAsync(d_parent_bufs_arr.data(), h_parent_bufs.data(), num_count_cols * sizeof(int32_t*), cudaMemcpyHostToDevice, stream.value())); - if (sp_timing_enabled) cudaEventRecord(t3, stream.value()); // end offsets + buffer alloc - // === Phase 4: Pass 2 - Extract === sp_unified_extract_kernel<<>>( *d_in, d_msg_types.data(), d_field_entries.data(), d_field_lookup.data(), d_col_descs.data(), d_row_offsets.data(), d_parent_bufs_arr.data(), num_count_cols, d_error.data()); - if (sp_timing_enabled) cudaEventRecord(t4, stream.value()); // end extract kernel - // === Phase 5: Compute Inner List Offsets === // Pre-compute which count columns are inner (parent_count_idx >= 0) and their sizes. 
// Bulk-allocate a single buffer for all inner offsets to avoid memory pool fragmentation. @@ -4878,8 +4857,6 @@ std::unique_ptr decode_nested_protobuf_single_pass( } } - if (sp_timing_enabled) cudaEventRecord(t5, stream.value()); // end inner offsets - // === Phase 6: Column Assembly === // Check for errors CUDF_CUDA_TRY(cudaPeekAtLastError()); @@ -4902,47 +4879,8 @@ std::unique_ptr decode_nested_protobuf_single_pass( } } - auto result = cudf::make_structs_column( + return cudf::make_structs_column( num_rows, std::move(top_children), 0, rmm::device_buffer{}, stream, mr); - - // Print timing results - if (sp_timing_enabled) { - cudaEventRecord(t6, stream.value()); // end column assembly - cudaEventSynchronize(t6); - - float ms[7]; - cudaEventElapsedTime(&ms[0], t_start, t1); // schema prep + device copy - cudaEventElapsedTime(&ms[1], t1, t2); // count kernel - cudaEventElapsedTime(&ms[2], t2, t3); // offsets + buffer alloc - cudaEventElapsedTime(&ms[3], t3, t4); // extract kernel - cudaEventElapsedTime(&ms[4], t4, t5); // inner offsets - cudaEventElapsedTime(&ms[5], t5, t6); // column assembly - cudaEventElapsedTime(&ms[6], t_start, t6); // total - - sp_call_count++; - sp_phase_totals[0] += ms[0]; sp_phase_totals[1] += ms[1]; - sp_phase_totals[2] += ms[2]; sp_phase_totals[3] += ms[3]; - sp_phase_totals[4] += ms[4]; sp_phase_totals[5] += ms[5]; - sp_phase_totals[6] += ms[6]; - - if (sp_call_count % 50 == 0) { - fprintf(stderr, - "[SP-TIMING] call#%d rows=%d fields=%d count_cols=%d out_cols=%d | " - "THIS: prep=%.1f count=%.1f offsets=%.1f extract=%.1f inner=%.1f assembly=%.1f TOTAL=%.1f ms | " - "CUMUL: prep=%.0f count=%.0f offsets=%.0f extract=%.0f inner=%.0f assembly=%.0f TOTAL=%.0f ms\n", - sp_call_count, num_rows, num_fields, num_count_cols, num_output_cols, - ms[0], ms[1], ms[2], ms[3], ms[4], ms[5], ms[6], - sp_phase_totals[0], sp_phase_totals[1], sp_phase_totals[2], - sp_phase_totals[3], sp_phase_totals[4], sp_phase_totals[5], - sp_phase_totals[6]); - } - - 
cudaEventDestroy(t_start); cudaEventDestroy(t1); cudaEventDestroy(t2); - cudaEventDestroy(t3); cudaEventDestroy(t4); cudaEventDestroy(t5); - cudaEventDestroy(t6); cudaEventDestroy(t7); - } - - return result; } } // anonymous namespace From 47973a737cf56d2cf4f5ddae0dd82025bea13f7f Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 10 Feb 2026 10:52:22 +0800 Subject: [PATCH 019/107] check point before schema projection option A, PROTOBUF_SINGLE_PASS=1 is slower now Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 51 +++++++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index c6abb11682..be27bcc656 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -4869,13 +4869,26 @@ std::unique_ptr decode_nested_protobuf_single_pass( } // Build top-level struct column + // For top-level repeated (LIST) columns, propagate input binary null mask. + // In protobuf: absent repeated field = [] (empty array), but null input row = null LIST. + // The old decoder did this via cudf::copy_bitmask(binary_input). We do the same here. 
+ auto const input_null_count = binary_input.null_count(); + std::vector> top_children; for (int i = 0; i < num_fields; i++) { if (schema[i].parent_idx == -1) { - top_children.push_back(sp_build_column_recursive( + auto col = sp_build_column_recursive( schema, schema_output_types, col_infos, schema_idx_to_info, col_data_ptrs, col_validity_ptrs, list_offsets_bufs, inner_offs_ptrs, inner_buf_sizes, - col_sizes, count_totals, col_elem_bytes, i, num_fields, num_rows, stream, mr)); + col_sizes, count_totals, col_elem_bytes, i, num_fields, num_rows, stream, mr); + + // Apply input null mask to top-level LIST columns (repeated fields) + if (input_null_count > 0 && schema[i].is_repeated) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + col->set_null_mask(std::move(null_mask), input_null_count); + } + + top_children.push_back(std::move(col)); } } @@ -4937,23 +4950,35 @@ std::unique_ptr decode_nested_protobuf_to_struct( return cudf::make_structs_column(0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); } - // Try single-pass decoder (faster for complex nested schemas) - // Can be disabled by setting PROTOBUF_NO_SINGLE_PASS=1 + // Choose decoder based on schema complexity. + // Single-pass decoder: fewer kernel launches but expensive column assembly. + // Old per-field decoder: more kernel launches but simpler assembly. + // For large schemas (>100 output cols), the old decoder is faster because + // single-pass assembly creates hundreds of cudf columns in one batch. + // Can override with PROTOBUF_SINGLE_PASS=1 (force single-pass) or =0 (force old). 
{ - char const* no_sp = std::getenv("PROTOBUF_NO_SINGLE_PASS"); - bool use_single_pass = !(no_sp && std::string(no_sp) == "1"); + char const* sp_env = std::getenv("PROTOBUF_SINGLE_PASS"); + bool force_sp = (sp_env && std::string(sp_env) == "1"); + bool force_old = (sp_env && std::string(sp_env) == "0"); + + // Count output columns (non-repeated leaf fields + struct containers) + int output_col_count = 0; + for (int i = 0; i < num_fields; i++) { + if (schema_output_types[i].id() != cudf::type_id::STRUCT || !schema[i].is_repeated) { + output_col_count++; + } + } + + // Auto-select: use single-pass for small schemas, old decoder for large ones + bool use_single_pass = force_sp || (!force_old && output_col_count <= 100); + if (use_single_pass) { auto result = decode_nested_protobuf_single_pass( binary_input, schema, schema_output_types, default_ints, default_floats, default_bools, default_strings, fail_on_errors); - CUDF_EXPECTS(result != nullptr, - "Single-pass protobuf decoder failed: schema exceeds limits " - "(SP_MAX_COUNTED=" + std::to_string(SP_MAX_COUNTED) + - " or SP_MAX_OUTPUT_COLS=" + std::to_string(SP_MAX_OUTPUT_COLS) + - ", actual num_count_cols or num_output_cols too large). 
" - "Set PROTOBUF_NO_SINGLE_PASS=1 to use the old decoder."); - return result; + if (result) return result; + // Fall through to old decoder if single-pass returns null (exceeds limits) } } From d2595b12edaf260cfa56aeaf217342e97fa748ee Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 11 Feb 2026 10:08:41 +0800 Subject: [PATCH 020/107] Merge flat and nested api for from_protobuf Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 148 +- src/main/cpp/src/protobuf.cu | 2220 +---------------- src/main/cpp/src/protobuf.hpp | 79 +- .../com/nvidia/spark/rapids/jni/Protobuf.java | 390 +-- .../nvidia/spark/rapids/jni/ProtobufTest.java | 152 +- 5 files changed, 224 insertions(+), 2765 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index f626ff291e..d40f4c0512 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -25,152 +25,6 @@ extern "C" { JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, - jclass, - jlong binary_input_view, - jint total_num_fields, - jintArray decoded_field_indices, - jintArray field_numbers, - jintArray all_type_ids, - jintArray encodings, - jbooleanArray is_required, - jbooleanArray has_default_value, - jlongArray default_ints, - jdoubleArray default_floats, - jbooleanArray default_bools, - jobjectArray default_strings, - jobjectArray enum_valid_values, - jboolean fail_on_errors) -{ - JNI_NULL_CHECK(env, binary_input_view, "binary_input_view is null", 0); - JNI_NULL_CHECK(env, decoded_field_indices, "decoded_field_indices is null", 0); - JNI_NULL_CHECK(env, field_numbers, "field_numbers is null", 0); - JNI_NULL_CHECK(env, all_type_ids, "all_type_ids is null", 0); - JNI_NULL_CHECK(env, encodings, "encodings is null", 0); - JNI_NULL_CHECK(env, is_required, "is_required is null", 0); - JNI_NULL_CHECK(env, has_default_value, "has_default_value is null", 0); - JNI_NULL_CHECK(env, default_ints, "default_ints is null", 
0); - JNI_NULL_CHECK(env, default_floats, "default_floats is null", 0); - JNI_NULL_CHECK(env, default_bools, "default_bools is null", 0); - JNI_NULL_CHECK(env, default_strings, "default_strings is null", 0); - JNI_NULL_CHECK(env, enum_valid_values, "enum_valid_values is null", 0); - - JNI_TRY - { - cudf::jni::auto_set_device(env); - auto const* input = reinterpret_cast(binary_input_view); - - cudf::jni::native_jintArray n_decoded_indices(env, decoded_field_indices); - cudf::jni::native_jintArray n_field_numbers(env, field_numbers); - cudf::jni::native_jintArray n_all_type_ids(env, all_type_ids); - cudf::jni::native_jintArray n_encodings(env, encodings); - cudf::jni::native_jbooleanArray n_is_required(env, is_required); - cudf::jni::native_jbooleanArray n_has_default(env, has_default_value); - cudf::jni::native_jlongArray n_default_ints(env, default_ints); - cudf::jni::native_jdoubleArray n_default_floats(env, default_floats); - cudf::jni::native_jbooleanArray n_default_bools(env, default_bools); - - int num_decoded_fields = n_decoded_indices.size(); - - // Validate array sizes - if (n_field_numbers.size() != num_decoded_fields || n_encodings.size() != num_decoded_fields || - n_is_required.size() != num_decoded_fields || n_has_default.size() != num_decoded_fields || - n_default_ints.size() != num_decoded_fields || - n_default_floats.size() != num_decoded_fields || - n_default_bools.size() != num_decoded_fields) { - JNI_THROW_NEW(env, - cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, - "All decoded field arrays must have the same length", - 0); - } - if (n_all_type_ids.size() != total_num_fields) { - JNI_THROW_NEW(env, - cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, - "all_type_ids size must equal total_num_fields", - 0); - } - - std::vector decoded_indices(n_decoded_indices.begin(), n_decoded_indices.end()); - std::vector field_nums(n_field_numbers.begin(), n_field_numbers.end()); - std::vector encs(n_encodings.begin(), n_encodings.end()); - - // Convert jboolean arrays to 
std::vector - std::vector required_flags; - std::vector has_default_flags; - std::vector default_bool_values; - required_flags.reserve(num_decoded_fields); - has_default_flags.reserve(num_decoded_fields); - default_bool_values.reserve(num_decoded_fields); - for (int i = 0; i < num_decoded_fields; ++i) { - required_flags.push_back(n_is_required[i] != 0); - has_default_flags.push_back(n_has_default[i] != 0); - default_bool_values.push_back(n_default_bools[i] != 0); - } - - // Convert default int/float values - std::vector default_int_values(n_default_ints.begin(), n_default_ints.end()); - std::vector default_float_values(n_default_floats.begin(), n_default_floats.end()); - - // Convert default string values (byte[][] -> vector>) - std::vector> default_string_values; - default_string_values.reserve(num_decoded_fields); - for (int i = 0; i < num_decoded_fields; ++i) { - jbyteArray byte_arr = static_cast(env->GetObjectArrayElement(default_strings, i)); - if (byte_arr == nullptr) { - default_string_values.emplace_back(); // empty vector for null - } else { - jsize len = env->GetArrayLength(byte_arr); - jbyte* bytes = env->GetByteArrayElements(byte_arr, nullptr); - default_string_values.emplace_back(reinterpret_cast(bytes), - reinterpret_cast(bytes) + len); - env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT); - } - } - - // Convert enum valid values (int[][] -> vector>) - // Each element is either null (not an enum field) or an array of valid enum values - std::vector> enum_values; - enum_values.reserve(num_decoded_fields); - for (int i = 0; i < num_decoded_fields; ++i) { - jintArray int_arr = static_cast(env->GetObjectArrayElement(enum_valid_values, i)); - if (int_arr == nullptr) { - enum_values.emplace_back(); // empty vector for null (not an enum field) - } else { - jsize len = env->GetArrayLength(int_arr); - jint* ints = env->GetIntArrayElements(int_arr, nullptr); - enum_values.emplace_back(ints, ints + len); - env->ReleaseIntArrayElements(int_arr, ints, 
JNI_ABORT); - } - } - - // Build all_types vector - types for ALL fields in the output struct - std::vector all_types; - all_types.reserve(total_num_fields); - for (int i = 0; i < total_num_fields; ++i) { - // For non-decimal types, scale is always 0 - all_types.emplace_back(cudf::jni::make_data_type(n_all_type_ids[i], 0)); - } - - auto result = spark_rapids_jni::decode_protobuf_to_struct(*input, - total_num_fields, - decoded_indices, - field_nums, - all_types, - encs, - required_flags, - has_default_flags, - default_int_values, - default_float_values, - default_bool_values, - default_string_values, - enum_values, - fail_on_errors); - return cudf::jni::release_as_jlong(result); - } - JNI_CATCH(env, 0); -} - -JNIEXPORT jlong JNICALL -Java_com_nvidia_spark_rapids_jni_Protobuf_decodeNestedToStruct(JNIEnv* env, jclass, jlong binary_input_view, jintArray field_numbers, @@ -309,7 +163,7 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeNestedToStruct(JNIEnv* env, } } - auto result = spark_rapids_jni::decode_nested_protobuf_to_struct( + auto result = spark_rapids_jni::decode_protobuf_to_struct( *input, schema, schema_output_types, diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index be27bcc656..dfd59a6ff9 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -32,18 +32,14 @@ #include #include -#include #include #include #include #include #include #include -#include #include -#include -#include #include #include #include @@ -109,61 +105,6 @@ struct device_nested_field_descriptor { bool has_default_value; }; -// ============================================================================ -// Single-pass decoder data structures -// ============================================================================ - -/// Maximum nesting depth for single-pass decoder -constexpr int SP_MAX_DEPTH = 10; - -/// Maximum number of counted columns (repeated fields at all depths) -constexpr int SP_MAX_COUNTED = 128; - -/// Maximum number 
of output columns -constexpr int SP_MAX_OUTPUT_COLS = 512; - -/// Message type descriptor: groups fields belonging to the same protobuf message -struct sp_msg_type { - int first_field_idx; // Start index in the global sp_field_entry array - int num_fields; // Number of direct child fields - int lookup_offset; // Offset into d_field_lookup table (-1 if not using lookup) - int max_field_number; // Max field number + 1 (size of lookup region) -}; - -/// Field entry for single-pass decoder (device-side, sorted by field_number per msg type) -struct sp_field_entry { - int field_number; // Protobuf field number - int wire_type; // Expected wire type - int output_type_id; // cudf type_id cast to int (-1 for struct containers) - int encoding; // ENC_DEFAULT / ENC_FIXED / ENC_ZIGZAG - int child_msg_type; // For nested messages: index into sp_msg_type (-1 otherwise) - int col_idx; // Index into output column descriptors (-1 for containers) - int count_idx; // For repeated fields: index into per-row count array (-1 if not) - bool is_repeated; // Whether this field is repeated - bool has_default; // Whether this field has a default value - int64_t default_int; // Default value for int/long/bool - double default_float; // Default value for float/double -}; - -/// Stack entry for nested message parsing within a kernel thread -struct sp_stack_entry { - int parent_end_offset; // End offset of parent message (relative to row start) - int msg_type_idx; // Saved message type index - int write_base; // Saved write base for non-repeated children -}; - -/// Output column descriptor (device-side, used during Pass 2) -struct sp_col_desc { - void* data; // Typed data buffer (or string_index_pair* for strings) - bool* validity; // Validity buffer (one bool per element) -}; - -/// Pair for zero-copy string references (device-side) -struct sp_string_pair { - char const* ptr; // Pointer into message data (null if not found) - int32_t length; // String length in bytes (0 if not found) -}; - // 
============================================================================ // Device helper functions // ============================================================================ @@ -820,77 +761,6 @@ __global__ void extract_fixed_from_locations_kernel(uint8_t const* message_data, valid[row] = true; } -/** - * Kernel to copy variable-length data (string/bytes) to output buffer. - * Uses pre-computed output offsets from prefix sum. - * Supports default values for missing fields. - */ -__global__ void copy_varlen_data_kernel( - uint8_t const* message_data, - cudf::size_type const* input_offsets, // List offsets for input rows - cudf::size_type base_offset, - field_location const* locations, - int field_idx, - int num_fields, - int32_t const* output_offsets, // Pre-computed output offsets (prefix sum) - char* output_data, - int num_rows, - bool has_default = false, - uint8_t const* default_data = nullptr, - int32_t default_length = 0) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - auto loc = locations[row * num_fields + field_idx]; - char* dst = output_data + output_offsets[row]; - - if (loc.offset < 0) { - // Field not found - use default if available - if (has_default && default_length > 0) { - for (int i = 0; i < default_length; i++) { - dst[i] = static_cast(default_data[i]); - } - } - return; - } - - if (loc.length == 0) return; - - auto row_start = input_offsets[row] - base_offset; - uint8_t const* src = message_data + row_start + loc.offset; - - // Copy data - for (int i = 0; i < loc.length; i++) { - dst[i] = static_cast(src[i]); - } -} - -/** - * Kernel to extract lengths from locations for prefix sum. - * Supports default values for missing fields. 
- */ -__global__ void extract_lengths_kernel(field_location const* locations, - int field_idx, - int num_fields, - int32_t* lengths, - int num_rows, - bool has_default = false, - int32_t default_length = 0) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - auto loc = locations[row * num_fields + field_idx]; - if (loc.offset >= 0) { - lengths[row] = loc.length; - } else if (has_default) { - lengths[row] = default_length; - } else { - lengths[row] = 0; - } -} - // ============================================================================ // Repeated field extraction kernels // ============================================================================ @@ -2038,30 +1908,6 @@ __global__ void copy_scalar_string_data_kernel( // Note: make_null_mask_from_valid is defined earlier in the file (before scan_repeated_message_children_kernel) -/** - * Get the expected wire type for a given cudf type and encoding. - */ -int get_expected_wire_type(cudf::type_id type_id, int encoding) -{ - switch (type_id) { - case cudf::type_id::BOOL8: - case cudf::type_id::INT32: - case cudf::type_id::UINT32: - case cudf::type_id::INT64: - case cudf::type_id::UINT64: - if (encoding == spark_rapids_jni::ENC_FIXED) { - return (type_id == cudf::type_id::INT32 || type_id == cudf::type_id::UINT32) ? WT_32BIT - : WT_64BIT; - } - return WT_VARINT; - case cudf::type_id::FLOAT32: return WT_32BIT; - case cudf::type_id::FLOAT64: return WT_64BIT; - case cudf::type_id::STRING: - case cudf::type_id::LIST: return WT_LEN; - default: CUDF_FAIL("Unsupported type for protobuf decoding"); - } -} - /** * Create an all-null column of the specified type. 
*/ @@ -2320,712 +2166,6 @@ __global__ void validate_enum_values_kernel( namespace spark_rapids_jni { -std::unique_ptr decode_protobuf_to_struct( - cudf::column_view const& binary_input, - int total_num_fields, - std::vector const& decoded_field_indices, - std::vector const& field_numbers, - std::vector const& all_types, - std::vector const& encodings, - std::vector const& is_required, - std::vector const& has_default_value, - std::vector const& default_ints, - std::vector const& default_floats, - std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - bool fail_on_errors) -{ - CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, - "binary_input must be a LIST column"); - cudf::lists_column_view const in_list(binary_input); - auto const child_type = in_list.child().type().id(); - CUDF_EXPECTS(child_type == cudf::type_id::INT8 || child_type == cudf::type_id::UINT8, - "binary_input must be a LIST column"); - CUDF_EXPECTS(static_cast(all_types.size()) == total_num_fields, - "all_types size must equal total_num_fields"); - CUDF_EXPECTS(decoded_field_indices.size() == field_numbers.size(), - "decoded_field_indices and field_numbers must have the same length"); - CUDF_EXPECTS(encodings.size() == field_numbers.size(), - "encodings and field_numbers must have the same length"); - CUDF_EXPECTS(is_required.size() == field_numbers.size(), - "is_required and field_numbers must have the same length"); - CUDF_EXPECTS(has_default_value.size() == field_numbers.size(), - "has_default_value and field_numbers must have the same length"); - CUDF_EXPECTS(default_ints.size() == field_numbers.size(), - "default_ints and field_numbers must have the same length"); - CUDF_EXPECTS(default_floats.size() == field_numbers.size(), - "default_floats and field_numbers must have the same length"); - CUDF_EXPECTS(default_bools.size() == field_numbers.size(), - "default_bools and field_numbers must have the same length"); - 
CUDF_EXPECTS(default_strings.size() == field_numbers.size(), - "default_strings and field_numbers must have the same length"); - - auto const stream = cudf::get_default_stream(); - auto mr = cudf::get_current_device_resource_ref(); - auto rows = binary_input.size(); - auto num_decoded_fields = static_cast(field_numbers.size()); - - // Handle zero-row case - if (rows == 0) { - std::vector> empty_children; - empty_children.reserve(total_num_fields); - for (auto const& dt : all_types) { - empty_children.push_back(make_empty_column_safe(dt, stream, mr)); - } - return cudf::make_structs_column( - 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); - } - - // Handle case with no fields to decode - if (num_decoded_fields == 0) { - std::vector> null_children; - null_children.reserve(total_num_fields); - for (auto const& dt : all_types) { - null_children.push_back(make_null_column(dt, rows, stream, mr)); - } - return cudf::make_structs_column( - rows, std::move(null_children), 0, rmm::device_buffer{}, stream, mr); - } - - auto d_in = cudf::column_device_view::create(binary_input, stream); - - // Prepare field descriptors for the scanning kernel - std::vector h_field_descs(num_decoded_fields); - for (int i = 0; i < num_decoded_fields; i++) { - int schema_idx = decoded_field_indices[i]; - h_field_descs[i].field_number = field_numbers[i]; - h_field_descs[i].expected_wire_type = - get_expected_wire_type(all_types[schema_idx].id(), encodings[i]); - } - - rmm::device_uvector d_field_descs(num_decoded_fields, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_descs.data(), - h_field_descs.data(), - num_decoded_fields * sizeof(field_descriptor), - cudaMemcpyHostToDevice, - stream.value())); - - // Allocate field locations array: [rows * num_decoded_fields] - rmm::device_uvector d_locations( - static_cast(rows) * num_decoded_fields, stream, mr); - - // Track errors - rmm::device_uvector d_error(1, stream, mr); - CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, 
sizeof(int), stream.value())); - - // Check if any field has enum validation - bool has_enum_fields = std::any_of( - enum_valid_values.begin(), enum_valid_values.end(), [](auto const& v) { return !v.empty(); }); - - // Track rows with invalid enum values (used to null entire struct row) - // This matches Spark CPU PERMISSIVE mode behavior - rmm::device_uvector d_row_has_invalid_enum(has_enum_fields ? rows : 0, stream, mr); - if (has_enum_fields) { - // Initialize all to false (no invalid enums yet) - CUDF_CUDA_TRY( - cudaMemsetAsync(d_row_has_invalid_enum.data(), 0, rows * sizeof(bool), stream.value())); - } - - auto const threads = 256; - auto const blocks = static_cast((rows + threads - 1) / threads); - - // ========================================================================= - // Pass 1: Scan all messages and record field locations - // ========================================================================= - scan_all_fields_kernel<<>>( - *d_in, d_field_descs.data(), num_decoded_fields, d_locations.data(), d_error.data()); - - // ========================================================================= - // Check required fields (after scan pass) - // ========================================================================= - // Only check if any field is required to avoid unnecessary kernel launch - bool has_required_fields = - std::any_of(is_required.begin(), is_required.end(), [](bool b) { return b; }); - if (has_required_fields) { - // Copy is_required flags to device - // Note: std::vector is special (bitfield), so we convert to uint8_t - rmm::device_uvector d_is_required(num_decoded_fields, stream, mr); - std::vector h_is_required_vec(num_decoded_fields); - for (int i = 0; i < num_decoded_fields; i++) { - h_is_required_vec[i] = is_required[i] ? 
1 : 0; - } - CUDF_CUDA_TRY(cudaMemcpyAsync(d_is_required.data(), - h_is_required_vec.data(), - num_decoded_fields * sizeof(uint8_t), - cudaMemcpyHostToDevice, - stream.value())); - - check_required_fields_kernel<<>>( - d_locations.data(), d_is_required.data(), num_decoded_fields, rows, d_error.data()); - } - - // Get message data pointer and offsets for pass 2 - auto const* message_data = reinterpret_cast(in_list.child().data()); - auto const* list_offsets = in_list.offsets().data(); - // Get the base offset by copying from device to host - cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync( - &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - - // ========================================================================= - // Pass 2: Extract data for each field - // ========================================================================= - std::vector> all_children(total_num_fields); - int decoded_idx = 0; - - for (int schema_idx = 0; schema_idx < total_num_fields; schema_idx++) { - if (decoded_idx < num_decoded_fields && decoded_field_indices[decoded_idx] == schema_idx) { - // This field needs to be decoded - auto const dt = all_types[schema_idx]; - auto const enc = encodings[decoded_idx]; - - switch (dt.id()) { - case cudf::type_id::BOOL8: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_val = has_def ? (default_bools[decoded_idx] ? 
1 : 0) : 0; - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_val); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); - break; - } - - case cudf::type_id::INT32: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_int = has_def ? default_ints[decoded_idx] : 0; - int32_t def_fixed = static_cast(def_int); - if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); - } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_fixed); - } else { - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); - } - - // Validate enum values if this is an enum field - // enum_valid_values[decoded_idx] is non-empty for enum fields - auto const& valid_enums = enum_valid_values[decoded_idx]; - if (!valid_enums.empty()) { - // Copy valid enum values to device - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - // 
Validate enum values - unknown values will null the entire row - validate_enum_values_kernel<<>>( - out.data(), - valid.data(), - d_row_has_invalid_enum.data(), - d_valid_enums.data(), - static_cast(valid_enums.size()), - rows); - } - - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); - break; - } - - case cudf::type_id::UINT32: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_int = has_def ? default_ints[decoded_idx] : 0; - uint32_t def_fixed = static_cast(def_int); - if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_fixed); - } else { - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); - break; - } - - case cudf::type_id::INT64: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_int = has_def ? 
default_ints[decoded_idx] : 0; - if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); - } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); - } else { - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); - break; - } - - case cudf::type_id::UINT64: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - int64_t def_int = has_def ? 
default_ints[decoded_idx] : 0; - uint64_t def_fixed = static_cast(def_int); - if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_fixed); - } else { - extract_varint_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); - break; - } - - case cudf::type_id::FLOAT32: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - float def_float = has_def ? static_cast(default_floats[decoded_idx]) : 0.0f; - extract_fixed_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_float); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); - break; - } - - case cudf::type_id::FLOAT64: { - rmm::device_uvector out(rows, stream, mr); - rmm::device_uvector valid(rows, stream, mr); - bool has_def = has_default_value[decoded_idx]; - double def_double = has_def ? 
default_floats[decoded_idx] : 0.0; - extract_fixed_from_locations_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - out.data(), - valid.data(), - rows, - d_error.data(), - has_def, - def_double); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - all_children[schema_idx] = - std::make_unique(dt, rows, out.release(), std::move(mask), null_count); - break; - } - - case cudf::type_id::STRING: { - // Check for default value - bool has_def = has_default_value[decoded_idx]; - auto const& def_str = default_strings[decoded_idx]; - int32_t def_len = has_def ? static_cast(def_str.size()) : 0; - - // Copy default string to device if needed - rmm::device_uvector d_default_str(def_len, stream, mr); - if (has_def && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), - def_str.data(), - def_len, - cudaMemcpyHostToDevice, - stream.value())); - } - - // Extract lengths and compute output offsets via prefix sum - rmm::device_uvector lengths(rows, stream, mr); - extract_lengths_kernel<<>>(d_locations.data(), - decoded_idx, - num_decoded_fields, - lengths.data(), - rows, - has_def, - def_len); - - rmm::device_uvector output_offsets(rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); - - // Get total size - int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, - output_offsets.data() + rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - lengths.data() + rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - total_chars += last_len; - - // Set the final offset - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + rows, - &total_chars, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - // Allocate and copy character data 
- rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - copy_varlen_data_kernel<<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - output_offsets.data(), - chars.data(), - rows, - has_def, - d_default_str.data(), - def_len); - } - - // Create validity mask (field found OR has default = valid) - rmm::device_uvector valid(rows, stream, mr); - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(rows), - valid.begin(), - [locs = d_locations.data(), decoded_idx, num_decoded_fields, has_def] __device__( - auto row) { - return locs[row * num_decoded_fields + decoded_idx].offset >= 0 || has_def; - }); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - - // Create offsets column - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - rows + 1, - output_offsets.release(), - rmm::device_buffer{}, - 0); - - // Create strings column using offsets + chars buffer - all_children[schema_idx] = cudf::make_strings_column( - rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); - break; - } - - case cudf::type_id::LIST: { - // For protobuf bytes: create LIST directly (optimization #2) - // Check for default value - bool has_def = has_default_value[decoded_idx]; - auto const& def_bytes = default_strings[decoded_idx]; - int32_t def_len = has_def ? 
static_cast(def_bytes.size()) : 0; - - // Copy default bytes to device if needed - rmm::device_uvector d_default_bytes(def_len, stream, mr); - if (has_def && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_bytes.data(), - def_bytes.data(), - def_len, - cudaMemcpyHostToDevice, - stream.value())); - } - - // Extract lengths and compute output offsets via prefix sum - rmm::device_uvector lengths(rows, stream, mr); - extract_lengths_kernel<<>>(d_locations.data(), - decoded_idx, - num_decoded_fields, - lengths.data(), - rows, - has_def, - def_len); - - rmm::device_uvector output_offsets(rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); - - // Get total size - int32_t total_bytes = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, - output_offsets.data() + rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - lengths.data() + rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - total_bytes += last_len; - - // Set the final offset - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + rows, - &total_bytes, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - // Allocate and copy byte data directly to INT8 buffer - rmm::device_uvector child_data(total_bytes, stream, mr); - if (total_bytes > 0) { - copy_varlen_data_kernel<<>>( - message_data, - list_offsets, - base_offset, - d_locations.data(), - decoded_idx, - num_decoded_fields, - output_offsets.data(), - reinterpret_cast(child_data.data()), - rows, - has_def, - d_default_bytes.data(), - def_len); - } - - // Create validity mask (field found OR has default = valid) - rmm::device_uvector valid(rows, stream, mr); - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(rows), - valid.begin(), - [locs = 
d_locations.data(), decoded_idx, num_decoded_fields, has_def] __device__( - auto row) { - return locs[row * num_decoded_fields + decoded_idx].offset >= 0 || has_def; - }); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - - // Create offsets column - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - rows + 1, - output_offsets.release(), - rmm::device_buffer{}, - 0); - - // Create INT8 child column directly (no intermediate strings column!) - auto child_col = std::make_unique(cudf::data_type{cudf::type_id::INT8}, - total_bytes, - child_data.release(), - rmm::device_buffer{}, - 0); - - all_children[schema_idx] = cudf::make_lists_column(rows, - std::move(offsets_col), - std::move(child_col), - null_count, - std::move(mask), - stream, - mr); - break; - } - - default: CUDF_FAIL("Unsupported output type for protobuf decoder"); - } - - decoded_idx++; - } else { - // This field is not decoded - create null column - all_children[schema_idx] = make_null_column(all_types[schema_idx], rows, stream, mr); - } - } - - // Check for errors - CUDF_CUDA_TRY(cudaPeekAtLastError()); - - // Check for any parse errors or missing required fields. - // Note: We check errors after all kernels complete rather than between kernel launches - // to avoid expensive synchronization overhead. If fail_on_errors is true and an error - // occurred, all kernels will have executed but we throw an exception here. 
- int h_error = 0; - CUDF_CUDA_TRY( - cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - if (fail_on_errors) { - CUDF_EXPECTS(h_error == 0, - "Malformed protobuf message, unsupported wire type, or missing required field"); - } - - // Build the final struct - // If any rows have invalid enum values, create a null mask for the struct - // This matches Spark CPU PERMISSIVE mode: unknown enum values null the entire row - cudf::size_type struct_null_count = 0; - rmm::device_buffer struct_mask{0, stream, mr}; - - if (has_enum_fields) { - // Create struct null mask: row is valid if it has NO invalid enums - auto [mask, null_count] = cudf::detail::valid_if( - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(rows), - [row_invalid = d_row_has_invalid_enum.data()] __device__(cudf::size_type row) { - return !row_invalid[row]; // valid if NOT invalid - }, - stream, - mr); - struct_mask = std::move(mask); - struct_null_count = null_count; - } - - return cudf::make_structs_column( - rows, std::move(all_children), struct_null_count, std::move(struct_mask), stream, mr); -} - -// ============================================================================ -// Nested protobuf decoding implementation -// ============================================================================ namespace { @@ -3653,1261 +2793,18 @@ std::unique_ptr build_repeated_struct_column( return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(struct_col), 0, rmm::device_buffer{}, stream, mr); } -// ============================================================================ -// Single-Pass Decoder Implementation -// ============================================================================ +} // anonymous namespace -/** - * O(1) field lookup using direct-mapped table. - * d_field_lookup[msg_type.lookup_offset + field_number] = field_entry index, or -1. 
- */ -__device__ inline int sp_lookup_field( - sp_msg_type const* msg_types, - sp_field_entry const* /*field_entries*/, - int const* d_field_lookup, - int msg_type_idx, - int field_number) -{ - auto const& mt = msg_types[msg_type_idx]; - if (field_number < 0 || field_number >= mt.max_field_number) return -1; - return d_field_lookup[mt.lookup_offset + field_number]; -} - -/** - * Write an extracted scalar value to the output column. - * cur is advanced past the consumed bytes. - */ -__device__ inline void sp_write_scalar( - uint8_t const*& cur, - uint8_t const* end, - sp_field_entry const& fe, - sp_col_desc* col_descs, - int write_pos) -{ - if (fe.col_idx < 0) return; - auto& cd = col_descs[fe.col_idx]; - - if (fe.wire_type == WT_VARINT) { - uint64_t val; int vb; - if (!read_varint(cur, end, val, vb)) return; - cur += vb; - if (fe.encoding == spark_rapids_jni::ENC_ZIGZAG) { - val = (val >> 1) ^ (-(val & 1)); - } - int tid = fe.output_type_id; - if (tid == static_cast(cudf::type_id::BOOL8)) - reinterpret_cast(cd.data)[write_pos] = val ? 
1 : 0; - else if (tid == static_cast(cudf::type_id::INT32)) - reinterpret_cast(cd.data)[write_pos] = static_cast(val); - else if (tid == static_cast(cudf::type_id::UINT32)) - reinterpret_cast(cd.data)[write_pos] = static_cast(val); - else if (tid == static_cast(cudf::type_id::INT64)) - reinterpret_cast(cd.data)[write_pos] = static_cast(val); - else if (tid == static_cast(cudf::type_id::UINT64)) - reinterpret_cast(cd.data)[write_pos] = val; - cd.validity[write_pos] = true; - - } else if (fe.wire_type == WT_32BIT) { - if (end - cur < 4) return; - uint32_t raw = load_le(cur); - cur += 4; - int tid = fe.output_type_id; - if (tid == static_cast(cudf::type_id::FLOAT32)) { - float f; memcpy(&f, &raw, 4); - reinterpret_cast(cd.data)[write_pos] = f; - } else { - reinterpret_cast(cd.data)[write_pos] = static_cast(raw); - } - cd.validity[write_pos] = true; - - } else if (fe.wire_type == WT_64BIT) { - if (end - cur < 8) return; - uint64_t raw = load_le(cur); - cur += 8; - int tid = fe.output_type_id; - if (tid == static_cast(cudf::type_id::FLOAT64)) { - double d; memcpy(&d, &raw, 8); - reinterpret_cast(cd.data)[write_pos] = d; - } else { - reinterpret_cast(cd.data)[write_pos] = static_cast(raw); - } - cd.validity[write_pos] = true; - - } else if (fe.wire_type == WT_LEN) { - // String / bytes - uint64_t len; int lb; - if (!read_varint(cur, end, len, lb)) return; - auto* pairs = reinterpret_cast(cd.data); - pairs[write_pos].ptr = reinterpret_cast(cur + lb); - pairs[write_pos].length = static_cast(len); - cd.validity[write_pos] = true; - cur += lb + static_cast(len); - } -} - -/** - * Count the number of packed elements in a length-delimited blob for a given element wire type. 
- */ -__device__ inline int sp_count_packed( - uint8_t const* data, int data_len, int elem_wire_type) -{ - if (elem_wire_type == WT_VARINT) { - int count = 0; - uint8_t const* p = data; - uint8_t const* pe = data + data_len; - while (p < pe) { - while (p < pe && (*p & 0x80u)) p++; - if (p < pe) { p++; count++; } - } - return count; - } else if (elem_wire_type == WT_32BIT) { - return data_len / 4; - } else if (elem_wire_type == WT_64BIT) { - return data_len / 8; - } - return 0; -} - -// ============================================================================ -// Pass 1: Unified Count Kernel -// Walks each message once, counting all repeated fields at all depths. -// ============================================================================ - -__global__ void sp_unified_count_kernel( - cudf::column_device_view const d_in, - sp_msg_type const* msg_types, - sp_field_entry const* fields, - int const* d_field_lookup, - int32_t* d_counts, // [num_rows * num_count_cols] - int num_count_cols, - int* error_flag) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= d_in.size()) return; - - auto const in = cudf::detail::lists_column_device_view(d_in); - auto const& child = in.child(); - auto const base = in.offsets().element(in.offset()); - auto const start = in.offset_at(row) - base; - auto const stop = in.offset_at(row + 1) - base; - auto const* bytes = reinterpret_cast(child.data()); - - // Local counters for each counted column (repeated fields) - int32_t local_counts[SP_MAX_COUNTED]; - for (int i = 0; i < num_count_cols && i < SP_MAX_COUNTED; i++) local_counts[i] = 0; - - // Stack for nested message parsing - sp_stack_entry stack[SP_MAX_DEPTH]; - int depth = 0; - int msg_type = 0; // Root message type - - uint8_t const* cur = bytes + start; - uint8_t const* end_ptr = bytes + stop; - - while (cur < end_ptr || depth > 0) { - if (cur >= end_ptr) { - if (depth <= 0) break; - depth--; - cur = end_ptr; - end_ptr = bytes + start + 
stack[depth].parent_end_offset; - msg_type = stack[depth].msg_type_idx; - continue; - } - - // Read tag - uint64_t key; int kb; - if (!read_varint(cur, end_ptr, key, kb)) { atomicExch(error_flag, 1); break; } - cur += kb; - int fn = static_cast(key >> 3); - int wt = static_cast(key & 0x7); - - int fi = sp_lookup_field(msg_types, fields, d_field_lookup, msg_type, fn); - - if (fi < 0) { - // Unknown field - skip - uint8_t const* next; - if (!skip_field(cur, end_ptr, wt, next)) { atomicExch(error_flag, 1); break; } - cur = next; - continue; - } - - auto const& fe = fields[fi]; - - // Check for packed encoding (repeated + WT_LEN but element is not LEN) - if (fe.is_repeated && wt == WT_LEN && fe.wire_type != WT_LEN && fe.count_idx >= 0) { - uint64_t len; int lb; - if (!read_varint(cur, end_ptr, len, lb)) { atomicExch(error_flag, 1); break; } - int packed_len = static_cast(len); - local_counts[fe.count_idx] += sp_count_packed(cur + lb, packed_len, fe.wire_type); - cur += lb + packed_len; - continue; - } - - // Wire type mismatch - skip - if (wt != fe.wire_type) { - uint8_t const* next; - if (!skip_field(cur, end_ptr, wt, next)) { atomicExch(error_flag, 1); break; } - cur = next; - continue; - } - - // Nested message field - if (fe.child_msg_type >= 0 && wt == WT_LEN) { - uint64_t len; int lb; - if (!read_varint(cur, end_ptr, len, lb)) { atomicExch(error_flag, 1); break; } - cur += lb; - int sub_end = static_cast((cur + static_cast(len)) - (bytes + start)); - - if (fe.is_repeated && fe.count_idx >= 0) { - local_counts[fe.count_idx]++; - } - - if (depth < SP_MAX_DEPTH) { - stack[depth] = {static_cast(end_ptr - (bytes + start)), msg_type, 0}; - depth++; - end_ptr = bytes + start + sub_end; - msg_type = fe.child_msg_type; - } else { - // Max depth exceeded - skip sub-message - cur += static_cast(len); - } - continue; - } - - // Repeated non-message field - if (fe.is_repeated && fe.count_idx >= 0) { - local_counts[fe.count_idx]++; - } - - // Skip field value - uint8_t const* 
next; - if (!skip_field(cur, end_ptr, wt, next)) { atomicExch(error_flag, 1); break; } - cur = next; - } - - // Write counts to global memory - for (int i = 0; i < num_count_cols && i < SP_MAX_COUNTED; i++) { - d_counts[static_cast(row) * num_count_cols + i] = local_counts[i]; - } -} - -// ============================================================================ -// Pass 2: Unified Extract Kernel -// Walks each message once, extracting all field values at all depths. -// ============================================================================ - -__global__ void sp_unified_extract_kernel( - cudf::column_device_view const d_in, - sp_msg_type const* msg_types, - sp_field_entry const* fields, - int const* d_field_lookup, - sp_col_desc* col_descs, - int32_t const* d_row_offsets, // [num_rows * num_count_cols] - per-row write offsets - int32_t* const* d_parent_bufs, // [num_count_cols] - parent index buffers (null if not inner) - int num_count_cols, - int* error_flag) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= d_in.size()) return; - - auto const in = cudf::detail::lists_column_device_view(d_in); - auto const& child = in.child(); - auto const base = in.offsets().element(in.offset()); - auto const start = in.offset_at(row) - base; - auto const stop = in.offset_at(row + 1) - base; - auto const* bytes = reinterpret_cast(child.data()); - - // Local write counters (initialized from row offsets) - int32_t local_counter[SP_MAX_COUNTED]; - for (int i = 0; i < num_count_cols && i < SP_MAX_COUNTED; i++) { - local_counter[i] = d_row_offsets[static_cast(row) * num_count_cols + i]; - } - - sp_stack_entry stack[SP_MAX_DEPTH]; - int depth = 0; - int msg_type = 0; - int write_base = row; // Write position for non-repeated children - - uint8_t const* cur = bytes + start; - uint8_t const* end_ptr = bytes + stop; - - while (cur < end_ptr || depth > 0) { - if (cur >= end_ptr) { - if (depth <= 0) break; - depth--; - cur = end_ptr; - end_ptr = bytes 
+ start + stack[depth].parent_end_offset; - msg_type = stack[depth].msg_type_idx; - write_base = stack[depth].write_base; - continue; - } - - // Read tag - uint64_t key; int kb; - if (!read_varint(cur, end_ptr, key, kb)) { atomicExch(error_flag, 1); break; } - cur += kb; - int fn = static_cast(key >> 3); - int wt = static_cast(key & 0x7); - - int fi = sp_lookup_field(msg_types, fields, d_field_lookup, msg_type, fn); - - if (fi < 0) { - uint8_t const* next; - if (!skip_field(cur, end_ptr, wt, next)) { atomicExch(error_flag, 1); break; } - cur = next; - continue; - } - - auto const& fe = fields[fi]; - - // Packed encoding for repeated scalars - if (fe.is_repeated && wt == WT_LEN && fe.wire_type != WT_LEN && fe.count_idx >= 0) { - uint64_t len; int lb; - if (!read_varint(cur, end_ptr, len, lb)) { atomicExch(error_flag, 1); break; } - uint8_t const* pstart = cur + lb; - uint8_t const* pend = pstart + static_cast(len); - uint8_t const* p = pstart; - - while (p < pend) { - int pos = local_counter[fe.count_idx]++; - if (d_parent_bufs && d_parent_bufs[fe.count_idx]) { - d_parent_bufs[fe.count_idx][pos] = write_base; - } - if (fe.col_idx >= 0) { - auto& cd = col_descs[fe.col_idx]; - if (fe.wire_type == WT_VARINT) { - uint64_t val; int vb; - if (!read_varint(p, pend, val, vb)) break; - p += vb; - if (fe.encoding == spark_rapids_jni::ENC_ZIGZAG) val = (val >> 1) ^ (-(val & 1)); - int tid = fe.output_type_id; - if (tid == static_cast(cudf::type_id::BOOL8)) - reinterpret_cast(cd.data)[pos] = val ? 
1 : 0; - else if (tid == static_cast(cudf::type_id::INT32)) - reinterpret_cast(cd.data)[pos] = static_cast(val); - else if (tid == static_cast(cudf::type_id::UINT32)) - reinterpret_cast(cd.data)[pos] = static_cast(val); - else if (tid == static_cast(cudf::type_id::INT64)) - reinterpret_cast(cd.data)[pos] = static_cast(val); - else if (tid == static_cast(cudf::type_id::UINT64)) - reinterpret_cast(cd.data)[pos] = val; - cd.validity[pos] = true; - } else if (fe.wire_type == WT_32BIT) { - if (pend - p < 4) break; - uint32_t raw = load_le(p); p += 4; - if (fe.output_type_id == static_cast(cudf::type_id::FLOAT32)) { - float f; memcpy(&f, &raw, 4); - reinterpret_cast(cd.data)[pos] = f; - } else { - reinterpret_cast(cd.data)[pos] = static_cast(raw); - } - cd.validity[pos] = true; - } else if (fe.wire_type == WT_64BIT) { - if (pend - p < 8) break; - uint64_t raw = load_le(p); p += 8; - if (fe.output_type_id == static_cast(cudf::type_id::FLOAT64)) { - double d; memcpy(&d, &raw, 8); - reinterpret_cast(cd.data)[pos] = d; - } else { - reinterpret_cast(cd.data)[pos] = static_cast(raw); - } - cd.validity[pos] = true; - } - } - } - cur = pend; - continue; - } - - // Wire type mismatch - skip - if (wt != fe.wire_type) { - uint8_t const* next; - if (!skip_field(cur, end_ptr, wt, next)) { atomicExch(error_flag, 1); break; } - cur = next; - continue; - } - - // Nested message - if (fe.child_msg_type >= 0 && wt == WT_LEN) { - uint64_t len; int lb; - if (!read_varint(cur, end_ptr, len, lb)) { atomicExch(error_flag, 1); break; } - cur += lb; - int sub_end = static_cast((cur + static_cast(len)) - (bytes + start)); - - int new_write_base = write_base; - if (fe.is_repeated && fe.count_idx >= 0) { - int p_pos = local_counter[fe.count_idx]++; - if (d_parent_bufs && d_parent_bufs[fe.count_idx]) { - d_parent_bufs[fe.count_idx][p_pos] = write_base; - } - new_write_base = p_pos; - } - // Set struct validity if we have a col_idx - if (fe.col_idx >= 0) { - 
col_descs[fe.col_idx].validity[new_write_base] = true; - } - - if (depth < SP_MAX_DEPTH) { - stack[depth] = {static_cast(end_ptr - (bytes + start)), msg_type, write_base}; - depth++; - end_ptr = bytes + start + sub_end; - msg_type = fe.child_msg_type; - write_base = new_write_base; - } else { - cur += static_cast(len); - } - continue; - } - - // Non-message field: extract value - if (fe.is_repeated && fe.count_idx >= 0) { - int pos = local_counter[fe.count_idx]++; - if (d_parent_bufs && d_parent_bufs[fe.count_idx]) { - d_parent_bufs[fe.count_idx][pos] = write_base; - } - sp_write_scalar(cur, end_ptr, fe, col_descs, pos); - } else { - // Non-repeated: write at write_base (last one wins on overwrite) - sp_write_scalar(cur, end_ptr, fe, col_descs, write_base); - } - } -} - -// ============================================================================ -// Fused prefix sum + list offsets kernels (replaces per-column thrust loops) -// ============================================================================ - -/** - * Compute exclusive prefix sums for ALL count columns in a single kernel launch. - * One thread per count column - each thread serially scans its column. - * Also writes the per-column totals and builds list offsets (num_rows+1). 
// ============================================================================
// Fused prefix sum + list offsets kernel (replaces per-column thrust loops)
// ============================================================================

/**
 * Compute exclusive prefix sums for ALL count columns in a single kernel launch.
 *
 * One thread per count column — each thread serially scans its own column of
 * the row-major counts matrix. Besides the per-row exclusive offsets, it also
 * writes the per-column total and (optionally) a cudf-style list-offsets
 * buffer of num_rows + 1 entries.
 *
 * NOTE(review): reconstructed from a whitespace-mangled patch; template
 * arguments inside <...> were stripped by the extraction and have been
 * inferred from usage — confirm against the original patch file.
 */
__global__ void sp_compute_offsets_kernel(
  int32_t const* d_counts,     // [num_rows x num_count_cols] row-major per-row counts
  int32_t* d_row_offsets,      // [num_rows x num_count_cols] row-major output (exclusive scan)
  int32_t* d_totals,           // [num_count_cols] output: per-column element totals
  int32_t** d_list_offs_ptrs,  // [num_count_cols] pointers to list offset buffers (num_rows+1 each)
  int num_rows,
  int num_count_cols)
{
  int c = blockIdx.x * blockDim.x + threadIdx.x;
  if (c >= num_count_cols) return;

  // May be null for columns that do not need cudf list offsets.
  int32_t* list_offs = d_list_offs_ptrs[c];
  int32_t sum = 0;
  for (int r = 0; r < num_rows; r++) {
    // Exclusive scan: publish the running sum BEFORE adding this row's count.
    auto idx = static_cast<size_t>(r) * num_count_cols + c;  // cast type inferred — TODO confirm
    int32_t val = d_counts[idx];
    d_row_offsets[idx] = sum;
    if (list_offs) list_offs[r] = sum;
    sum += val;
  }
  d_totals[c] = sum;  // total element count for this column across all rows
  if (list_offs) list_offs[num_rows] = sum;  // terminating offset (cudf convention)
}

// ============================================================================
// Host-side helpers for single-pass decoder
// ============================================================================

/// Host-side column info used during final column assembly.
struct sp_host_col_info {
  int schema_idx;         // index into the flat schema arrays
  int col_idx;            // col_idx in sp_col_desc (-1 for repeated struct containers)
  int count_idx;          // for repeated fields (-1 otherwise)
  int parent_count_idx;   // count_idx of nearest repeated ancestor (-1 for top-level)
  cudf::type_id type_id;  // cudf output type of this field
  bool is_repeated;
  bool is_string;         // STRING or LIST (bytes) — both stored as sp_string_pair
  int parent_schema_idx;  // -1 for top-level
};
/**
 * Build the single-pass decoder schema from nested_field_descriptor arrays.
 *
 * Produces:
 *  - msg_types / field_entries: device-friendly message-type tables, with the
 *    children of each message sorted by protobuf field number,
 *  - col_infos: host-side assembly metadata (one entry per schema field),
 *  - field_lookup_table: a direct-mapped [field_number -> field_entry index]
 *    table per message type for O(1) field lookup in the kernels,
 *  - num_count_cols / num_output_cols: how many repeated-count slots and
 *    output data columns the kernels must track.
 *
 * Index assignment is a DFS over the schema tree so that every non-repeated
 * leaf and every struct container gets a col_idx, and every repeated field
 * gets a count_idx; repeated structs get only a count_idx (their list offsets
 * come from counts, their data from the children).
 *
 * NOTE(review): reconstructed from a whitespace-mangled patch; all template
 * arguments inside <...> were stripped by extraction and have been inferred
 * from usage (e.g. vector element types) — confirm against the original patch.
 */
void build_single_pass_schema(
  std::vector<nested_field_descriptor> const& schema,
  std::vector<cudf::data_type> const& schema_output_types,
  std::vector<int64_t> const& default_ints,
  std::vector<double> const& default_floats,
  std::vector<bool> const& default_bools,  // element type inferred — TODO confirm
  // Outputs:
  std::vector<sp_msg_type>& msg_types,
  std::vector<sp_field_entry>& field_entries,
  std::vector<sp_host_col_info>& col_infos,
  std::vector<int>& field_lookup_table,
  int& num_count_cols,
  int& num_output_cols)
{
  int num_fields = static_cast<int>(schema.size());

  // Group children by parent_idx (-1 is the virtual root message).
  std::map<int, std::vector<int>> parent_to_children;
  for (int i = 0; i < num_fields; i++) {
    parent_to_children[schema[i].parent_idx].push_back(i);
  }

  // Assign message type indices: root first, then each struct parent.
  std::map<int, int> parent_to_msg_type;
  int msg_type_counter = 0;
  parent_to_msg_type[-1] = msg_type_counter++;

  for (int i = 0; i < num_fields; i++) {
    auto type_id = schema_output_types[i].id();
    if (type_id == cudf::type_id::STRUCT && parent_to_children.count(i) > 0) {
      parent_to_msg_type[i] = msg_type_counter++;
    }
  }

  // Assign col_idx and count_idx via DFS.
  int col_counter = 0;
  int count_counter = 0;
  std::map<int, int> schema_to_col_idx;
  std::map<int, int> schema_to_count_idx;

  std::function<void(int, int)> assign_indices = [&](int parent_idx, int parent_count_idx) {
    auto it = parent_to_children.find(parent_idx);
    if (it == parent_to_children.end()) return;

    for (int si : it->second) {
      auto type_id = schema_output_types[si].id();
      bool is_repeated = schema[si].is_repeated;
      bool is_struct = (type_id == cudf::type_id::STRUCT);
      // STRING and LIST (bytes) are both length-delimited and stored as sp_string_pair.
      bool is_string = (type_id == cudf::type_id::STRING || type_id == cudf::type_id::LIST);

      int my_count_idx = -1;
      if (is_repeated) {
        my_count_idx = count_counter++;
        schema_to_count_idx[si] = my_count_idx;
      }

      int my_col_idx = -1;
      // All non-repeated-struct fields get a col_idx for data writing.
      // Non-repeated struct containers also get one for validity tracking.
      if (is_struct && !is_repeated) {
        my_col_idx = col_counter++;
        schema_to_col_idx[si] = my_col_idx;
      } else if (!is_struct) {
        my_col_idx = col_counter++;
        schema_to_col_idx[si] = my_col_idx;
      }
      // Repeated structs: no col_idx (list offsets from count, struct from children).

      sp_host_col_info info{};
      info.schema_idx = si;
      info.col_idx = my_col_idx;
      info.count_idx = my_count_idx;
      info.parent_count_idx = parent_count_idx;
      info.type_id = type_id;
      info.is_repeated = is_repeated;
      info.is_string = is_string;
      info.parent_schema_idx = parent_idx;
      col_infos.push_back(info);

      if (is_struct) {
        // Children of a repeated struct inherit its count_idx as their
        // "nearest repeated ancestor"; otherwise the ancestor is unchanged.
        int child_parent_count = is_repeated ? my_count_idx : parent_count_idx;
        assign_indices(si, child_parent_count);
      }
    }
  };

  assign_indices(-1, -1);
  num_count_cols = count_counter;
  num_output_cols = col_counter;

  // Build sp_msg_type and sp_field_entry arrays.
  msg_types.resize(msg_type_counter);
  for (auto& [pidx, mt_idx] : parent_to_msg_type) {
    auto it = parent_to_children.find(pidx);
    if (it == parent_to_children.end()) {
      // Struct with no visible children: empty field range.
      msg_types[mt_idx] = {static_cast<int>(field_entries.size()), 0, -1, 0};
      continue;
    }
    auto children = it->second;
    // Sort by protobuf field number so the lookup table below is compact.
    std::sort(children.begin(), children.end(), [&](int a, int b) {
      return schema[a].field_number < schema[b].field_number;
    });

    int first_idx = static_cast<int>(field_entries.size());
    for (int si : children) {
      sp_field_entry e{};
      e.field_number = schema[si].field_number;
      e.wire_type = schema[si].wire_type;
      e.output_type_id = static_cast<int>(schema_output_types[si].id());
      e.encoding = schema[si].encoding;
      e.is_repeated = schema[si].is_repeated;
      e.has_default = schema[si].has_default_value;
      e.default_int = e.has_default ? default_ints[si] : 0;
      e.default_float = e.has_default ? default_floats[si] : 0.0;

      auto type_id = schema_output_types[si].id();
      if (type_id == cudf::type_id::STRUCT) {
        auto mt_it = parent_to_msg_type.find(si);
        e.child_msg_type = (mt_it != parent_to_msg_type.end()) ? mt_it->second : -1;
      } else {
        e.child_msg_type = -1;
      }

      auto col_it = schema_to_col_idx.find(si);
      e.col_idx = (col_it != schema_to_col_idx.end()) ? col_it->second : -1;
      auto cnt_it = schema_to_count_idx.find(si);
      e.count_idx = (cnt_it != schema_to_count_idx.end()) ? cnt_it->second : -1;

      field_entries.push_back(e);
    }
    msg_types[mt_idx] = {first_idx, static_cast<int>(children.size()), -1, 0};
  }

  // Build direct-mapped field lookup table for O(1) field lookup.
  // For each message type, allocate a region of [0..max_field_number) in the table.
  // table[offset + field_number] = index into field_entries, or -1 if not found.
  int lookup_offset = 0;
  for (int mt = 0; mt < msg_type_counter; mt++) {
    auto& mtype = msg_types[mt];
    if (mtype.num_fields == 0) {
      mtype.lookup_offset = lookup_offset;
      mtype.max_field_number = 1;  // at least 1 to avoid zero-size region
      field_lookup_table.push_back(-1);
      lookup_offset += 1;
      continue;
    }
    // Find max field number in this message type.
    int max_fn = 0;
    for (int f = mtype.first_field_idx; f < mtype.first_field_idx + mtype.num_fields; f++) {
      max_fn = std::max(max_fn, field_entries[f].field_number);
    }
    int table_size = max_fn + 1;
    mtype.lookup_offset = lookup_offset;
    mtype.max_field_number = table_size;

    // Fill with -1 (not found), then set entries for known fields.
    int base = static_cast<int>(field_lookup_table.size());
    field_lookup_table.resize(base + table_size, -1);
    for (int f = mtype.first_field_idx; f < mtype.first_field_idx + mtype.num_fields; f++) {
      field_lookup_table[base + field_entries[f].field_number] = f;
    }
    lookup_offset += table_size;
  }
}
/**
 * Recursively build a cudf column for a field in the schema.
 *
 * Assembles the final cudf column tree from the raw buffers the extract kernel
 * filled in:
 *  - STRUCT fields recurse into their children (wrapped in a LIST when repeated),
 *  - repeated leaves become LIST columns whose offsets come from either the
 *    top-level list-offset buffers (outer repeats) or the inner-offset buffers
 *    (repeats nested under another repeated field),
 *  - STRING / bytes leaves are materialized from sp_string_pair (ptr,len) views
 *    that point back into the input binary column,
 *  - scalar leaves copy their slice of the bulk data buffer and derive a null
 *    mask from the bool validity buffer via cudf::detail::valid_if.
 *
 * Returns the assembled column; fields absent from schema_idx_to_info become
 * all-null columns.
 *
 * NOTE(review): reconstructed from a whitespace-mangled patch; template
 * arguments inside <...> were stripped by extraction and have been inferred
 * from usage (pointer element types, rmm::device_uvector payloads) — confirm
 * against the original patch file.
 */
std::unique_ptr<cudf::column> sp_build_column_recursive(
  std::vector<nested_field_descriptor> const& schema,
  std::vector<cudf::data_type> const& schema_output_types,
  std::vector<sp_host_col_info> const& col_infos,
  std::map<int, sp_host_col_info const*> const& schema_idx_to_info,
  // Buffers (bulk-allocated):
  std::vector<uint8_t*>& col_data_ptrs,        // col_idx -> data pointer in bulk buffer
  std::vector<bool*>& col_validity_ptrs,       // col_idx -> validity pointer in bulk buffer
  std::vector<rmm::device_uvector<int32_t>>& list_offsets_bufs,  // count_idx -> offsets (top-level)
  std::vector<int32_t*>& inner_offs_ptrs,      // count_idx -> inner offsets pointer (or null)
  std::vector<int>& inner_buf_sizes,           // count_idx -> inner offsets size (0 if not inner)
  std::vector<int32_t>& col_sizes,             // col_idx -> element count
  std::vector<int32_t>& count_totals,          // count_idx -> total count
  std::vector<size_t>& col_elem_bytes,         // col_idx -> element byte size
  int schema_idx,
  int num_fields,
  int num_rows,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr)
{
  auto it = schema_idx_to_info.find(schema_idx);
  if (it == schema_idx_to_info.end()) {
    // Field was not assigned any buffers: produce an all-null placeholder.
    return make_null_column(schema_output_types[schema_idx], num_rows, stream, mr);
  }
  auto const& info = *(it->second);
  auto type_id = info.type_id;
  bool is_repeated = info.is_repeated;
  bool is_string = info.is_string;

  // Determine element count for this column: rows at the top level, otherwise
  // the total element count of the nearest repeated ancestor.
  int elem_count = num_rows;
  if (info.parent_count_idx >= 0) {
    elem_count = count_totals[info.parent_count_idx];
  }

  if (type_id == cudf::type_id::STRUCT) {
    // Find children of this struct.
    std::vector<int> child_schema_indices;
    for (int i = 0; i < num_fields; i++) {
      if (schema[i].parent_idx == schema_idx) child_schema_indices.push_back(i);
    }

    if (is_repeated) {
      // LIST<STRUCT>: build struct children, then wrap in list.
      int total = count_totals[info.count_idx];
      std::vector<std::unique_ptr<cudf::column>> struct_children;
      for (int child_si : child_schema_indices) {
        struct_children.push_back(sp_build_column_recursive(
          schema, schema_output_types, col_infos, schema_idx_to_info,
          col_data_ptrs, col_validity_ptrs, list_offsets_bufs, inner_offs_ptrs, inner_buf_sizes,
          col_sizes, count_totals, col_elem_bytes, child_si, num_fields, num_rows, stream, mr));
      }
      auto struct_col = cudf::make_structs_column(
        total, std::move(struct_children), 0, rmm::device_buffer{}, stream, mr);

      // List offsets: use inner offsets for nested repeated fields, top-level otherwise.
      std::unique_ptr<cudf::column> offsets_col;
      if (inner_offs_ptrs[info.count_idx] != nullptr) {
        int sz = inner_buf_sizes[info.count_idx];
        // Copy out of the shared bulk inner-offsets buffer so the column owns its data.
        auto buf = rmm::device_buffer(sz * sizeof(int32_t), stream, mr);
        CUDF_CUDA_TRY(cudaMemcpyAsync(buf.data(), inner_offs_ptrs[info.count_idx],
          sz * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream.value()));
        offsets_col = std::make_unique<cudf::column>(
          cudf::data_type{cudf::type_id::INT32}, sz, std::move(buf), rmm::device_buffer{}, 0);
      } else {
        // Top-level: the offsets buffer can be released (moved) into the column.
        auto& offs = list_offsets_bufs[info.count_idx];
        auto const offs_size = static_cast<cudf::size_type>(offs.size());
        offsets_col = std::make_unique<cudf::column>(
          cudf::data_type{cudf::type_id::INT32}, offs_size, offs.release(), rmm::device_buffer{}, 0);
      }

      return cudf::make_lists_column(
        elem_count, std::move(offsets_col), std::move(struct_col),
        0, rmm::device_buffer{}, stream, mr);
    } else {
      // Non-repeated struct.
      std::vector<std::unique_ptr<cudf::column>> struct_children;
      for (int child_si : child_schema_indices) {
        struct_children.push_back(sp_build_column_recursive(
          schema, schema_output_types, col_infos, schema_idx_to_info,
          col_data_ptrs, col_validity_ptrs, list_offsets_bufs, inner_offs_ptrs, inner_buf_sizes,
          col_sizes, count_totals, col_elem_bytes, child_si, num_fields, num_rows, stream, mr));
      }
      // Struct validity from col_idx (set by the extract kernel when the
      // nested message was present in the wire data).
      int ci = info.col_idx;
      if (ci >= 0 && col_validity_ptrs[ci] != nullptr && col_sizes[ci] > 0) {
        auto [mask, null_count] = cudf::detail::valid_if(
          thrust::make_counting_iterator(0),
          thrust::make_counting_iterator(elem_count),
          [vld = col_validity_ptrs[ci]] __device__ (cudf::size_type i) { return vld[i]; },
          stream, mr);
        return cudf::make_structs_column(
          elem_count, std::move(struct_children), null_count, std::move(mask), stream, mr);
      }
      return cudf::make_structs_column(
        elem_count, std::move(struct_children), 0, rmm::device_buffer{}, stream, mr);
    }
  }

  // Leaf field (scalar or string).
  int ci = info.col_idx;
  if (ci < 0) {
    return make_null_column(schema_output_types[schema_idx], elem_count, stream, mr);
  }

  // Helper lambda: build a STRING column from sp_string_pair data.
  auto build_string_col = [&](int col_idx, int count, bool use_validity) -> std::unique_ptr<cudf::column> {
    auto* pairs = reinterpret_cast<sp_string_pair*>(col_data_ptrs[col_idx]);
    rmm::device_uvector<cudf::strings::detail::string_index_pair> str_pairs(count, stream, mr);
    if (use_validity) {
      // Nullable path: {nullptr, 0} marks a null string for make_strings_column.
      thrust::transform(rmm::exec_policy(stream),
        thrust::make_counting_iterator(0), thrust::make_counting_iterator(count),
        str_pairs.begin(),
        [pairs, vld = col_validity_ptrs[col_idx]] __device__ (int i) -> cudf::strings::detail::string_index_pair {
          if (vld[i]) return {pairs[i].ptr, pairs[i].length};
          return {nullptr, 0};
        });
    } else {
      thrust::transform(rmm::exec_policy(stream),
        thrust::make_counting_iterator(0), thrust::make_counting_iterator(count),
        str_pairs.begin(),
        [pairs] __device__ (int i) -> cudf::strings::detail::string_index_pair {
          return {pairs[i].ptr, pairs[i].length};
        });
    }
    return cudf::strings::detail::make_strings_column(
      str_pairs.begin(), str_pairs.end(), stream, mr);
  };

  // Helper lambda: build a LIST<UINT8> (bytes/binary) column from sp_string_pair data.
  auto build_bytes_col = [&](int col_idx, int count) -> std::unique_ptr<cudf::column> {
    auto* pairs = reinterpret_cast<sp_string_pair*>(col_data_ptrs[col_idx]);
    auto* vld = col_validity_ptrs[col_idx];
    // Compute lengths and prefix sum -> offsets.
    rmm::device_uvector<int32_t> byte_offs(count + 1, stream, mr);
    if (count > 0) {
      // Compute lengths directly into offsets[0..count-1], then exclusive_scan in place.
      thrust::transform(rmm::exec_policy(stream),
        thrust::make_counting_iterator(0), thrust::make_counting_iterator(count),
        byte_offs.begin(),
        [pairs, vld] __device__ (int i) -> int32_t { return vld[i] ? pairs[i].length : 0; });
      thrust::exclusive_scan(rmm::exec_policy(stream),
        byte_offs.begin(), byte_offs.begin() + count, byte_offs.begin(), 0);
      // Total bytes via transform_reduce (avoids D->H sync for last_off + last_len).
      int32_t total_bytes = thrust::transform_reduce(rmm::exec_policy(stream),
        thrust::make_counting_iterator(0), thrust::make_counting_iterator(count),
        [pairs, vld] __device__ (int i) -> int32_t {
          return vld[i] ? pairs[i].length : 0;
        }, 0, cuda::std::plus{});
      // NOTE(review): async H->D copy from a stack variable — relies on
      // pageable-memory copies behaving synchronously w.r.t. the host; confirm.
      CUDF_CUDA_TRY(cudaMemcpyAsync(byte_offs.data() + count, &total_bytes,
        sizeof(int32_t), cudaMemcpyHostToDevice, stream.value()));
      // Copy binary data into the child buffer.
      rmm::device_uvector<uint8_t> child_data(total_bytes > 0 ? total_bytes : 0, stream, mr);
      if (total_bytes > 0) {
        thrust::for_each(rmm::exec_policy(stream),
          thrust::make_counting_iterator(0), thrust::make_counting_iterator(count),
          [pairs, offs = byte_offs.data(), out = child_data.data(), vld] __device__ (int i) {
            if (vld[i] && pairs[i].ptr && pairs[i].length > 0) {
              memcpy(out + offs[i], pairs[i].ptr, pairs[i].length);
            }
          });
      }
      auto off_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
        count + 1, byte_offs.release(), rmm::device_buffer{}, 0);
      auto ch_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::UINT8},
        total_bytes, child_data.release(), rmm::device_buffer{}, 0);
      auto [mask, null_count] = cudf::detail::valid_if(
        thrust::make_counting_iterator(0),
        thrust::make_counting_iterator(count),
        [v = vld] __device__ (cudf::size_type i) { return v[i]; }, stream, mr);
      return cudf::make_lists_column(count, std::move(off_col), std::move(ch_col),
        null_count, std::move(mask), stream, mr);
    } else {
      // Empty bytes column: all-zero offsets, no child data.
      thrust::fill(rmm::exec_policy(stream), byte_offs.begin(), byte_offs.end(), 0);
      auto off_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
        count + 1, byte_offs.release(), rmm::device_buffer{}, 0);
      auto ch_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::UINT8},
        0, rmm::device_buffer{}, rmm::device_buffer{}, 0);
      return cudf::make_lists_column(count, std::move(off_col), std::move(ch_col),
        0, rmm::device_buffer{}, stream, mr);
    }
  };

  bool is_bytes = (type_id == cudf::type_id::LIST);

  if (is_repeated) {
    // LIST: build child column then wrap in list.
    int total = count_totals[info.count_idx];
    std::unique_ptr<cudf::column> child_col;

    if (is_bytes) {
      // repeated bytes -> LIST<LIST<UINT8>>: build inner LIST then wrap in outer list.
      child_col = build_bytes_col(ci, total);
    } else if (is_string) {
      // Repeated elements are always present — no per-element validity needed.
      child_col = build_string_col(ci, total, false);
    } else {
      auto dt = schema_output_types[schema_idx];
      auto [mask, null_count] = cudf::detail::valid_if(
        thrust::make_counting_iterator(0),
        thrust::make_counting_iterator(total),
        [v = col_validity_ptrs[ci]] __device__ (cudf::size_type i) { return v[i]; }, stream, mr);
      // Copy data from bulk buffer into a new device_buffer for cudf::column ownership.
      auto data_buf = rmm::device_buffer(total * col_elem_bytes[ci], stream, mr);
      CUDF_CUDA_TRY(cudaMemcpyAsync(data_buf.data(), col_data_ptrs[ci],
        total * col_elem_bytes[ci], cudaMemcpyDeviceToDevice, stream.value()));
      child_col = std::make_unique<cudf::column>(
        dt, total, std::move(data_buf), std::move(mask), null_count);
    }

    // Use inner offsets if available, else use top-level list offsets.
    std::unique_ptr<cudf::column> offsets_col;
    if (inner_offs_ptrs[info.count_idx] != nullptr) {
      int sz = inner_buf_sizes[info.count_idx];
      auto buf = rmm::device_buffer(sz * sizeof(int32_t), stream, mr);
      CUDF_CUDA_TRY(cudaMemcpyAsync(buf.data(), inner_offs_ptrs[info.count_idx],
        sz * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream.value()));
      offsets_col = std::make_unique<cudf::column>(
        cudf::data_type{cudf::type_id::INT32}, sz, std::move(buf), rmm::device_buffer{}, 0);
    } else {
      auto& offs = list_offsets_bufs[info.count_idx];
      auto const offs_size = static_cast<cudf::size_type>(offs.size());
      offsets_col = std::make_unique<cudf::column>(
        cudf::data_type{cudf::type_id::INT32}, offs_size, offs.release(), rmm::device_buffer{}, 0);
    }

    return cudf::make_lists_column(
      elem_count, std::move(offsets_col), std::move(child_col),
      0, rmm::device_buffer{}, stream, mr);
  }

  // Non-repeated leaf.
  if (is_bytes) {
    return build_bytes_col(ci, elem_count);
  }
  if (is_string) {
    return build_string_col(ci, elem_count, true);
  }

  // Non-repeated non-string scalar.
  auto dt = schema_output_types[schema_idx];
  auto [mask, null_count] = cudf::detail::valid_if(
    thrust::make_counting_iterator(0),
    thrust::make_counting_iterator(elem_count),
    [v = col_validity_ptrs[ci]] __device__ (cudf::size_type i) { return v[i]; }, stream, mr);
  // Copy data from bulk buffer into a new device_buffer for cudf::column ownership.
  auto data_buf = rmm::device_buffer(elem_count * col_elem_bytes[ci], stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(data_buf.data(), col_data_ptrs[ci],
    elem_count * col_elem_bytes[ci], cudaMemcpyDeviceToDevice, stream.value()));
  return std::make_unique<cudf::column>(
    dt, elem_count, std::move(data_buf), std::move(mask), null_count);
}
/**
 * Main single-pass decoder orchestration.
 *
 * Pipeline:
 *  1. Flatten the schema into device-friendly message-type / field-entry
 *     tables plus an O(1) field-number lookup table.
 *  2. Pass 1 kernel: count repeated-field elements per row.
 *  3. One fused kernel computes all per-row offsets, per-column totals, and
 *     top-level list offsets; buffers for every output column are then carved
 *     out of TWO bulk allocations (data + validity).
 *  4. Pass 2 kernel: extract all values directly into the bulk buffers.
 *  5. Inner (nested-repeat) list offsets are derived from parent-index
 *     buffers via thrust::lower_bound.
 *  6. Columns are assembled recursively into the final top-level STRUCT.
 *
 * Returns nullptr when the schema exceeds the single-pass limits
 * (SP_MAX_COUNTED / SP_MAX_OUTPUT_COLS), signalling the caller to fall back
 * to the per-field decoder. On malformed input, throws only when
 * fail_on_errors is set.
 *
 * NOTE(review): reconstructed from a whitespace-mangled patch; template
 * arguments and the <<<...>>> kernel launch configurations were stripped by
 * extraction and have been inferred from surrounding code — confirm against
 * the original patch file.
 */
std::unique_ptr<cudf::column> decode_nested_protobuf_single_pass(
  cudf::column_view const& binary_input,
  std::vector<nested_field_descriptor> const& schema,
  std::vector<cudf::data_type> const& schema_output_types,
  std::vector<int64_t> const& default_ints,
  std::vector<double> const& default_floats,
  std::vector<bool> const& default_bools,
  std::vector<std::vector<char>> const& default_strings,  // element type inferred — TODO confirm
  bool fail_on_errors)
{
  auto const stream = cudf::get_default_stream();
  auto mr = cudf::get_current_device_resource_ref();
  auto num_rows = binary_input.size();
  auto num_fields = static_cast<int>(schema.size());

  // === Phase 1: Schema Preprocessing ===
  std::vector<sp_msg_type> h_msg_types;
  std::vector<sp_field_entry> h_field_entries;
  std::vector<sp_host_col_info> col_infos;
  std::vector<int> h_field_lookup;
  int num_count_cols = 0;
  int num_output_cols = 0;

  build_single_pass_schema(schema, schema_output_types,
    default_ints, default_floats, default_bools,
    h_msg_types, h_field_entries, col_infos, h_field_lookup,
    num_count_cols, num_output_cols);

  // Check limits: the kernels track per-thread counters in fixed-size arrays.
  if (num_count_cols > SP_MAX_COUNTED || num_output_cols > SP_MAX_OUTPUT_COLS) {
    return nullptr;  // Signal caller to fall back to old decoder.
  }

  // Copy schema tables to device.
  rmm::device_uvector<sp_msg_type> d_msg_types(h_msg_types.size(), stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(d_msg_types.data(), h_msg_types.data(),
    h_msg_types.size() * sizeof(sp_msg_type), cudaMemcpyHostToDevice, stream.value()));

  rmm::device_uvector<sp_field_entry> d_field_entries(h_field_entries.size(), stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_entries.data(), h_field_entries.data(),
    h_field_entries.size() * sizeof(sp_field_entry), cudaMemcpyHostToDevice, stream.value()));

  // Copy O(1) field lookup table to device.
  rmm::device_uvector<int> d_field_lookup(h_field_lookup.size(), stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_lookup.data(), h_field_lookup.data(),
    h_field_lookup.size() * sizeof(int), cudaMemcpyHostToDevice, stream.value()));

  auto d_in = cudf::column_device_view::create(binary_input, stream);

  rmm::device_uvector<int> d_error(1, stream, mr);
  CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value()));

  int const threads = 256;
  int const blocks = (num_rows + threads - 1) / threads;

  // === Phase 2: Pass 1 - Count ===
  rmm::device_uvector<int32_t> d_counts(
    num_count_cols > 0 ? static_cast<size_t>(num_rows) * num_count_cols : 1, stream, mr);
  if (num_count_cols > 0) {
    CUDF_CUDA_TRY(cudaMemsetAsync(d_counts.data(), 0,
      static_cast<size_t>(num_rows) * num_count_cols * sizeof(int32_t), stream.value()));
  }

  sp_unified_count_kernel<<<blocks, threads, 0, stream.value()>>>(
    *d_in, d_msg_types.data(), d_field_entries.data(), d_field_lookup.data(),
    d_counts.data(), num_count_cols, d_error.data());

  // === Phase 3: Compute Offsets and Allocate Buffers ===
  // Fused: compute all prefix sums + list offsets in a SINGLE kernel launch.
  // Replaces ~50 syncs + ~200 kernel launches with 1 kernel + 1 sync.
  rmm::device_uvector<int32_t> d_row_offsets(
    num_count_cols > 0 ? static_cast<size_t>(num_rows) * num_count_cols : 1, stream, mr);

  std::vector<int32_t> count_totals(num_count_cols, 0);
  std::vector<rmm::device_uvector<int32_t>> list_offsets_bufs;
  list_offsets_bufs.reserve(num_count_cols);

  // Pre-allocate all list offset buffers and collect device pointers.
  std::vector<int32_t*> h_list_offs_ptrs(num_count_cols, nullptr);
  for (int c = 0; c < num_count_cols; c++) {
    list_offsets_bufs.emplace_back(num_rows + 1, stream, mr);
    h_list_offs_ptrs[c] = list_offsets_bufs.back().data();
  }

  if (num_count_cols > 0) {
    // Copy list offset pointers to device.
    rmm::device_uvector<int32_t*> d_list_offs_ptrs(num_count_cols, stream, mr);
    CUDF_CUDA_TRY(cudaMemcpyAsync(d_list_offs_ptrs.data(), h_list_offs_ptrs.data(),
      num_count_cols * sizeof(int32_t*), cudaMemcpyHostToDevice, stream.value()));

    // Device buffer for totals.
    rmm::device_uvector<int32_t> d_totals(num_count_cols, stream, mr);

    // Single fused kernel: prefix sums + totals + list offsets for all columns.
    int const off_threads = std::min(num_count_cols, 256);
    int const off_blocks = (num_count_cols + off_threads - 1) / off_threads;
    sp_compute_offsets_kernel<<<off_blocks, off_threads, 0, stream.value()>>>(
      d_counts.data(), d_row_offsets.data(), d_totals.data(),
      d_list_offs_ptrs.data(), num_rows, num_count_cols);

    // Single D->H copy for all totals + single sync.
    CUDF_CUDA_TRY(cudaMemcpyAsync(count_totals.data(), d_totals.data(),
      num_count_cols * sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value()));
    stream.synchronize();
  }

  // Build schema_idx -> col_info lookup.
  std::map<int, sp_host_col_info const*> schema_idx_to_info;
  for (auto const& ci : col_infos) {
    schema_idx_to_info[ci.schema_idx] = &ci;
  }

  // Determine buffer sizes for each col_idx.
  std::vector<int32_t> col_sizes(num_output_cols, 0);
  for (auto const& ci : col_infos) {
    if (ci.col_idx < 0) continue;
    if (ci.is_repeated && ci.count_idx >= 0) {
      col_sizes[ci.col_idx] = count_totals[ci.count_idx];
    } else if (ci.parent_count_idx >= 0) {
      col_sizes[ci.col_idx] = count_totals[ci.parent_count_idx];
    } else {
      col_sizes[ci.col_idx] = num_rows;
    }
  }

  // Build col_idx -> col_info lookup (avoids O(N^2) inner loop).
  std::vector<sp_host_col_info const*> col_idx_to_info(num_output_cols, nullptr);
  for (auto const& c : col_infos) {
    if (c.col_idx >= 0) col_idx_to_info[c.col_idx] = &c;
  }

  // Compute per-column element sizes and total buffer sizes for BULK allocation.
  // Replaces ~992 individual RMM allocations + ~992 memsets with 2 allocs + 2 memsets.
  std::vector<size_t> col_elem_bytes(num_output_cols, 0);
  std::vector<size_t> col_data_offsets(num_output_cols, 0);
  std::vector<size_t> col_validity_offsets(num_output_cols, 0);
  size_t total_data_bytes = 0;
  size_t total_validity_elems = 0;

  for (int ci_idx = 0; ci_idx < num_output_cols; ci_idx++) {
    auto const* cinfo = col_idx_to_info[ci_idx];
    size_t eb = 0;
    if (cinfo) {
      auto tid = cinfo->type_id;
      if (tid == cudf::type_id::STRING || tid == cudf::type_id::LIST) {
        eb = sizeof(sp_string_pair);  // length-delimited fields store (ptr,len) views
      } else if (tid == cudf::type_id::STRUCT) {
        eb = 0;  // struct containers carry only validity, no data
      } else {
        eb = cudf::size_of(cudf::data_type{tid});
      }
    }
    col_elem_bytes[ci_idx] = eb;
    int32_t sz = col_sizes[ci_idx];
    // Align data offset to 16 bytes for coalesced GPU access.
    col_data_offsets[ci_idx] = total_data_bytes;
    total_data_bytes += (sz > 0 ? sz * eb : 0);
    total_data_bytes = (total_data_bytes + 15) & ~size_t{15};  // 16-byte align

    col_validity_offsets[ci_idx] = total_validity_elems;
    total_validity_elems += (sz > 0 ? sz : 0);
  }

  // TWO bulk allocations instead of ~992 individual ones.
  rmm::device_uvector<uint8_t> bulk_data(total_data_bytes > 0 ? total_data_bytes : 1, stream, mr);
  rmm::device_uvector<bool> bulk_validity(total_validity_elems > 0 ? total_validity_elems : 1, stream, mr);

  // TWO bulk memsets instead of ~992 individual ones.
  if (total_data_bytes > 0) {
    CUDF_CUDA_TRY(cudaMemsetAsync(bulk_data.data(), 0, total_data_bytes, stream.value()));
  }
  if (total_validity_elems > 0) {
    CUDF_CUDA_TRY(cudaMemsetAsync(bulk_validity.data(), 0, total_validity_elems * sizeof(bool), stream.value()));
  }

  // Per-column pointers into the bulk buffers.
  std::vector<uint8_t*> col_data_ptrs(num_output_cols, nullptr);
  std::vector<bool*> col_validity_ptrs(num_output_cols, nullptr);

  for (int ci_idx = 0; ci_idx < num_output_cols; ci_idx++) {
    int32_t sz = col_sizes[ci_idx];
    if (sz > 0) {
      col_data_ptrs[ci_idx] = bulk_data.data() + col_data_offsets[ci_idx];
      col_validity_ptrs[ci_idx] = bulk_validity.data() + col_validity_offsets[ci_idx];
    }
  }

  // Fill non-zero defaults (rare - proto3 defaults are all 0, so the bulk
  // memsets above already handle the common case).
  for (int ci_idx = 0; ci_idx < num_output_cols; ci_idx++) {
    auto const* cinfo = col_idx_to_info[ci_idx];
    int32_t sz = col_sizes[ci_idx];
    if (!cinfo || sz <= 0) continue;
    if (cinfo->type_id == cudf::type_id::STRING ||
        cinfo->type_id == cudf::type_id::LIST ||
        cinfo->type_id == cudf::type_id::STRUCT) continue;
    int si = cinfo->schema_idx;
    if (!schema[si].has_default_value) continue;

    auto tid = cinfo->type_id;
    bool non_zero = false;
    if (tid == cudf::type_id::BOOL8) non_zero = default_bools[si];
    else if (tid == cudf::type_id::FLOAT32 || tid == cudf::type_id::FLOAT64)
      non_zero = (default_floats[si] != 0.0);
    else non_zero = (default_ints[si] != 0);

    if (non_zero) {
      // Defaulted fields are considered present; extract pass may overwrite.
      thrust::fill_n(rmm::exec_policy(stream), col_validity_ptrs[ci_idx], sz, true);
      auto* dp = col_data_ptrs[ci_idx];
      if (tid == cudf::type_id::BOOL8)
        thrust::fill_n(rmm::exec_policy(stream), dp, sz, static_cast<uint8_t>(1));  // cast type inferred — TODO confirm
      else if (tid == cudf::type_id::INT32 || tid == cudf::type_id::UINT32)
        thrust::fill_n(rmm::exec_policy(stream), reinterpret_cast<int32_t*>(dp), sz, static_cast<int32_t>(default_ints[si]));
      else if (tid == cudf::type_id::INT64 || tid == cudf::type_id::UINT64)
        thrust::fill_n(rmm::exec_policy(stream), reinterpret_cast<int64_t*>(dp), sz, default_ints[si]);
      else if (tid == cudf::type_id::FLOAT32)
        thrust::fill_n(rmm::exec_policy(stream), reinterpret_cast<float*>(dp), sz, static_cast<float>(default_floats[si]));
      else if (tid == cudf::type_id::FLOAT64)
        thrust::fill_n(rmm::exec_policy(stream), reinterpret_cast<double*>(dp), sz, default_floats[si]);
    }
  }

  // Build device-side column descriptors (using bulk buffer pointers).
  std::vector<sp_col_desc> h_col_descs(num_output_cols);
  for (int i = 0; i < num_output_cols; i++) {
    h_col_descs[i].data = col_data_ptrs[i];
    h_col_descs[i].validity = col_validity_ptrs[i];
  }
  rmm::device_uvector<sp_col_desc> d_col_descs(num_output_cols, stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(d_col_descs.data(), h_col_descs.data(),
    num_output_cols * sizeof(sp_col_desc), cudaMemcpyHostToDevice, stream.value()));

  // Allocate parent index buffers for inner repeated fields (used later to
  // derive inner list offsets via lower_bound).
  std::vector<rmm::device_uvector<int32_t>> parent_idx_storage;
  std::vector<int32_t*> h_parent_bufs(num_count_cols, nullptr);
  parent_idx_storage.reserve(num_count_cols);

  for (auto const& ci : col_infos) {
    if (ci.count_idx >= 0 && ci.parent_count_idx >= 0) {
      // Inner repeated field: needs parent index buffer.
      int total = count_totals[ci.count_idx];
      parent_idx_storage.emplace_back(total > 0 ? total : 0, stream, mr);
      h_parent_bufs[ci.count_idx] = parent_idx_storage.back().data();
    }
  }

  rmm::device_uvector<int32_t*> d_parent_bufs_arr(num_count_cols, stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(d_parent_bufs_arr.data(), h_parent_bufs.data(),
    num_count_cols * sizeof(int32_t*), cudaMemcpyHostToDevice, stream.value()));

  // === Phase 4: Pass 2 - Extract ===
  sp_unified_extract_kernel<<<blocks, threads, 0, stream.value()>>>(
    *d_in, d_msg_types.data(), d_field_entries.data(), d_field_lookup.data(),
    d_col_descs.data(), d_row_offsets.data(),
    d_parent_bufs_arr.data(), num_count_cols, d_error.data());

  // === Phase 5: Compute Inner List Offsets ===
  // Pre-compute which count columns are inner (parent_count_idx >= 0) and their sizes.
  // Bulk-allocate a single buffer for all inner offsets to avoid memory pool fragmentation.
  struct inner_info_t { int count_idx; int parent_count_idx; int total_child; int total_parent; };
  std::vector<inner_info_t> inner_infos;
  size_t total_inner_elems = 0;
  std::vector<int> inner_buf_offsets(num_count_cols, -1);  // offset into bulk inner buffer
  std::vector<int> inner_buf_sizes(num_count_cols, 0);

  for (int c = 0; c < num_count_cols; c++) {
    sp_host_col_info const* cinfo_ptr = nullptr;
    for (auto const& ci : col_infos) {
      if (ci.count_idx == c) { cinfo_ptr = &ci; break; }
    }
    if (cinfo_ptr && cinfo_ptr->parent_count_idx >= 0) {
      int total_child = count_totals[c];
      int total_parent = count_totals[cinfo_ptr->parent_count_idx];
      int sz = total_parent + 1;
      inner_buf_offsets[c] = static_cast<int>(total_inner_elems);
      inner_buf_sizes[c] = sz;
      total_inner_elems += sz;
      inner_infos.push_back({c, cinfo_ptr->parent_count_idx, total_child, total_parent});
    }
  }

  // Single bulk allocation for all inner offsets.
  rmm::device_uvector<int32_t> bulk_inner_offsets(
    total_inner_elems > 0 ? total_inner_elems : 1, stream, mr);
  if (total_inner_elems > 0) {
    CUDF_CUDA_TRY(cudaMemsetAsync(bulk_inner_offsets.data(), 0,
      total_inner_elems * sizeof(int32_t), stream.value()));
  }

  // Compute inner offsets via lower_bound (only for non-empty inner fields).
  // parent-index buffers are sorted by construction (rows processed in order),
  // so lower_bound over them yields the exclusive offsets per parent element.
  for (auto const& ii : inner_infos) {
    int c = ii.count_idx;
    int32_t* out = bulk_inner_offsets.data() + inner_buf_offsets[c];
    if (ii.total_child > 0 && ii.total_parent > 0 && h_parent_bufs[c] != nullptr) {
      thrust::lower_bound(rmm::exec_policy(stream),
        h_parent_bufs[c], h_parent_bufs[c] + ii.total_child,
        thrust::make_counting_iterator(0),
        thrust::make_counting_iterator(ii.total_parent + 1),
        out);
    }
    // else: already zeroed by memset.
  }

  // Build inner offset pointer array (points into bulk buffer, no per-column allocation).
  std::vector<int32_t*> inner_offs_ptrs(num_count_cols, nullptr);
  for (int c = 0; c < num_count_cols; c++) {
    if (inner_buf_offsets[c] >= 0) {
      inner_offs_ptrs[c] = bulk_inner_offsets.data() + inner_buf_offsets[c];
    }
  }

  // === Phase 6: Column Assembly ===
  // Check for errors raised by either kernel pass.
  CUDF_CUDA_TRY(cudaPeekAtLastError());
  int h_error = 0;
  CUDF_CUDA_TRY(cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int),
    cudaMemcpyDeviceToHost, stream.value()));
  stream.synchronize();
  if (fail_on_errors) {
    CUDF_EXPECTS(h_error == 0, "Malformed protobuf message or unsupported wire type");
  }

  // Build top-level struct column.
  // For top-level repeated (LIST) columns, propagate input binary null mask.
  // In protobuf: absent repeated field = [] (empty array), but null input row = null LIST.
  // The old decoder did this via cudf::copy_bitmask(binary_input). We do the same here.
  auto const input_null_count = binary_input.null_count();

  std::vector<std::unique_ptr<cudf::column>> top_children;
  for (int i = 0; i < num_fields; i++) {
    if (schema[i].parent_idx == -1) {
      auto col = sp_build_column_recursive(
        schema, schema_output_types, col_infos, schema_idx_to_info,
        col_data_ptrs, col_validity_ptrs, list_offsets_bufs, inner_offs_ptrs, inner_buf_sizes,
        col_sizes, count_totals, col_elem_bytes, i, num_fields, num_rows, stream, mr);

      // Apply input null mask to top-level LIST columns (repeated fields).
      if (input_null_count > 0 && schema[i].is_repeated) {
        auto null_mask = cudf::copy_bitmask(binary_input, stream, mr);
        col->set_null_mask(std::move(null_mask), input_null_count);
      }

      top_children.push_back(std::move(col));
    }
  }

  return cudf::make_structs_column(
    num_rows, std::move(top_children), 0, rmm::device_buffer{}, stream, mr);
}

}  // anonymous namespace
- // Single-pass decoder: fewer kernel launches but expensive column assembly. - // Old per-field decoder: more kernel launches but simpler assembly. - // For large schemas (>100 output cols), the old decoder is faster because - // single-pass assembly creates hundreds of cudf columns in one batch. - // Can override with PROTOBUF_SINGLE_PASS=1 (force single-pass) or =0 (force old). - { - char const* sp_env = std::getenv("PROTOBUF_SINGLE_PASS"); - bool force_sp = (sp_env && std::string(sp_env) == "1"); - bool force_old = (sp_env && std::string(sp_env) == "0"); - - // Count output columns (non-repeated leaf fields + struct containers) - int output_col_count = 0; - for (int i = 0; i < num_fields; i++) { - if (schema_output_types[i].id() != cudf::type_id::STRUCT || !schema[i].is_repeated) { - output_col_count++; - } - } - - // Auto-select: use single-pass for small schemas, old decoder for large ones - bool use_single_pass = force_sp || (!force_old && output_col_count <= 100); - - if (use_single_pass) { - auto result = decode_nested_protobuf_single_pass( - binary_input, schema, schema_output_types, - default_ints, default_floats, default_bools, default_strings, - fail_on_errors); - if (result) return result; - // Fall through to old decoder if single-pass returns null (exceeds limits) - } - } - // Copy schema to device std::vector h_device_schema(num_fields); for (int i = 0; i < num_fields; i++) { @@ -5030,6 +2895,15 @@ std::unique_ptr decode_nested_protobuf_to_struct( rmm::device_uvector d_error(1, stream, mr); CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + // Enum validation support (PERMISSIVE mode) + bool has_enum_fields = std::any_of( + enum_valid_values.begin(), enum_valid_values.end(), [](auto const& v) { return !v.empty(); }); + rmm::device_uvector d_row_has_invalid_enum(has_enum_fields ? 
num_rows : 0, stream, mr); + if (has_enum_fields) { + CUDF_CUDA_TRY(cudaMemsetAsync(d_row_has_invalid_enum.data(), 0, + num_rows * sizeof(bool), stream.value())); + } + auto const threads = 256; auto const blocks = static_cast((num_rows + threads - 1) / threads); @@ -5091,6 +2965,26 @@ std::unique_ptr decode_nested_protobuf_to_struct( scan_all_fields_kernel<<>>( *d_in, d_field_descs.data(), num_scalar, d_locations.data(), d_error.data()); + // Check required fields (after scan pass) + { + bool has_required = false; + for (int i = 0; i < num_scalar; i++) { + int si = scalar_field_indices[i]; + if (schema[si].is_required) { has_required = true; break; } + } + if (has_required) { + rmm::device_uvector d_is_required(num_scalar, stream, mr); + std::vector h_is_required(num_scalar); + for (int i = 0; i < num_scalar; i++) { + h_is_required[i] = schema[scalar_field_indices[i]].is_required ? 1 : 0; + } + CUDF_CUDA_TRY(cudaMemcpyAsync(d_is_required.data(), h_is_required.data(), + num_scalar * sizeof(uint8_t), cudaMemcpyHostToDevice, stream.value())); + check_required_fields_kernel<<>>( + d_locations.data(), d_is_required.data(), num_scalar, num_rows, d_error.data()); + } + } + // Extract scalar values (reusing existing extraction logic) cudf::lists_column_view const in_list_view(binary_input); auto const* message_data = reinterpret_cast(in_list_view.child().data()); @@ -5137,6 +3031,18 @@ std::unique_ptr decode_nested_protobuf_to_struct( message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); } + // Enum validation: check if this INT32 field has valid enum values + if (schema_idx < static_cast(enum_valid_values.size())) { + auto const& valid_enums = enum_valid_values[schema_idx]; + if (!valid_enums.empty()) { + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), valid_enums.data(), + valid_enums.size() * 
sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + validate_enum_values_kernel<<>>( + out.data(), valid.data(), d_row_has_invalid_enum.data(), + d_valid_enums.data(), static_cast(valid_enums.size()), num_rows); + } + } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); column_map[schema_idx] = std::make_unique( dt, num_rows, out.release(), std::move(mask), null_count); @@ -5856,10 +3762,28 @@ std::unique_ptr decode_nested_protobuf_to_struct( CUDF_CUDA_TRY(cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); if (fail_on_errors) { - CUDF_EXPECTS(h_error == 0, "Malformed protobuf message or unsupported wire type"); + CUDF_EXPECTS(h_error == 0, + "Malformed protobuf message, unsupported wire type, or missing required field"); + } + + // Build final struct with PERMISSIVE mode null mask for invalid enums + cudf::size_type struct_null_count = 0; + rmm::device_buffer struct_mask{0, stream, mr}; + + if (has_enum_fields) { + auto [mask, null_count] = cudf::detail::valid_if( + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + [row_invalid = d_row_has_invalid_enum.data()] __device__(cudf::size_type row) { + return !row_invalid[row]; + }, + stream, mr); + struct_mask = std::move(mask); + struct_null_count = null_count; } - return cudf::make_structs_column(num_rows, std::move(top_level_children), 0, rmm::device_buffer{}, stream, mr); + return cudf::make_structs_column( + num_rows, std::move(top_level_children), struct_null_count, std::move(struct_mask), stream, mr); } } // namespace spark_rapids_jni diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index 30a79c95c7..e8259d5065 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -51,84 +51,27 @@ struct nested_field_descriptor { /** * Decode protobuf messages (one message per row) from a LIST column into a STRUCT - * column. 
+ * column, with support for nested messages and repeated fields. * - * This uses a two-pass approach for efficiency: - * - Pass 1: Scan all messages once, recording (offset, length) for each requested field - * - Pass 2: Extract data in parallel using the recorded locations + * This uses a multi-pass approach: + * - Pass 1: Scan all messages, count nested elements and repeated field occurrences + * - Pass 2: Prefix sum to compute output offsets for arrays and nested structs + * - Pass 3: Extract data using pre-computed offsets + * - Pass 4: Build nested column structure * - * This is significantly faster than the per-field approach when decoding multiple fields, - * as each message is only parsed once regardless of the number of fields. + * The schema is represented as a flattened array of field descriptors with parent-child + * relationships. Top-level fields have parent_idx == -1 and depth == 0. For pure scalar + * schemas, all fields are top-level with is_repeated == false. * * Supported output child types (cudf dtypes) and corresponding protobuf field types: * - BOOL8 : protobuf `bool` (varint wire type) * - INT32 : protobuf `int32`, `sint32` (with zigzag), `fixed32`/`sfixed32` (with fixed encoding) - * - UINT32 : protobuf `uint32`, `fixed32` (with fixed encoding) * - INT64 : protobuf `int64`, `sint64` (with zigzag), `fixed64`/`sfixed64` (with fixed encoding) - * - UINT64 : protobuf `uint64`, `fixed64` (with fixed encoding) * - FLOAT32 : protobuf `float` (fixed32 wire type) * - FLOAT64 : protobuf `double` (fixed64 wire type) * - STRING : protobuf `string` (length-delimited wire type, UTF-8 text) * - LIST : protobuf `bytes` (length-delimited wire type, raw bytes as LIST) - * - * Integer handling: - * - For standard varint-encoded fields (`int32`, `int64`, `uint32`, `uint64`), use encoding=0. - * - For zigzag-encoded signed fields (`sint32`, `sint64`), use encoding=2. - * - For fixed-width fields (`fixed32`, `fixed64`, `sfixed32`, `sfixed64`), use encoding=1. 
- * - * Nested messages, repeated fields, map fields, and oneof fields are out of scope for this API. - * - * @param binary_input LIST column, each row is one protobuf message - * @param total_num_fields Total number of fields in the output struct (including null columns) - * @param decoded_field_indices Indices into the output struct for fields that should be decoded. - * Fields not in this list will be null columns in the output. - * @param field_numbers Protobuf field numbers for decoded fields (parallel to - * decoded_field_indices) - * @param all_types Output cudf data types for ALL fields in the struct (size = total_num_fields) - * @param encodings Encoding type for each decoded field (0=default, 1=fixed, 2=zigzag) - * (parallel to decoded_field_indices) - * @param is_required Whether each decoded field is required (parallel to decoded_field_indices). - * If a required field is missing and fail_on_errors is true, an exception is - * thrown. - * @param has_default_value Whether each decoded field has a default value (parallel to - * decoded_field_indices) - * @param default_ints Default values for int/long/enum fields (parallel to decoded_field_indices) - * @param default_floats Default values for float/double fields (parallel to decoded_field_indices) - * @param default_bools Default values for bool fields (parallel to decoded_field_indices) - * @param default_strings Default values for string/bytes fields (parallel to decoded_field_indices) - * @param enum_valid_values Valid enum values for each field (parallel to decoded_field_indices). - * Empty vector means not an enum field. Non-empty vector contains the - * valid enum values. Unknown enum values will be set to null. - * @param fail_on_errors Whether to throw on malformed messages or missing required fields. - * Note: error checking is performed after all kernels complete (not between kernel launches) - * to avoid synchronization overhead. - * @return STRUCT column with total_num_fields children. 
Decoded fields contain the parsed data, - * other fields contain all nulls. The STRUCT itself is always non-null. - */ -std::unique_ptr decode_protobuf_to_struct( - cudf::column_view const& binary_input, - int total_num_fields, - std::vector const& decoded_field_indices, - std::vector const& field_numbers, - std::vector const& all_types, - std::vector const& encodings, - std::vector const& is_required, - std::vector const& has_default_value, - std::vector const& default_ints, - std::vector const& default_floats, - std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - bool fail_on_errors); - -/** - * Decode protobuf messages with support for nested messages and repeated fields. - * - * This uses a multi-pass approach: - * - Pass 1: Scan all messages, count nested elements and repeated field occurrences - * - Pass 2: Prefix sum to compute output offsets for arrays and nested structs - * - Pass 3: Extract data using pre-computed offsets - * - Pass 4: Build nested column structure + * - STRUCT : protobuf nested `message` * * @param binary_input LIST column, each row is one protobuf message * @param schema Flattened schema with parent-child relationships @@ -141,7 +84,7 @@ std::unique_ptr decode_protobuf_to_struct( * @param fail_on_errors Whether to throw on malformed data * @return STRUCT column with nested structure */ -std::unique_ptr decode_nested_protobuf_to_struct( +std::unique_ptr decode_protobuf_to_struct( cudf::column_view const& binary_input, std::vector const& schema, std::vector const& schema_output_types, diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index 170be5e5c1..068ccdbe18 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -23,25 +23,27 @@ /** * GPU protobuf decoding utilities. 
* - * This API uses a two-pass approach for efficient decoding: + * This API uses a multi-pass approach for efficient decoding: *

      - *
    • Pass 1: Scan all messages once, recording (offset, length) for each requested field
    • - *
    • Pass 2: Extract data in parallel using the recorded locations
    • + *
    • Pass 1: Scan all messages, count nested elements and repeated field occurrences
    • + *
    • Pass 2: Prefix sum to compute output offsets for arrays and nested structs
    • + *
    • Pass 3: Extract data using pre-computed offsets
    • + *
    • Pass 4: Build nested column structure
    • *
    * - * This is significantly faster than per-field parsing when decoding multiple fields, - * as each message is only parsed once regardless of the number of fields. + * The schema is represented as a flattened array of field descriptors with parent-child + * relationships. Top-level fields have parentIndices == -1 and depthLevels == 0. + * For pure scalar schemas, all fields are top-level with isRepeated == false. * - * Supported protobuf field types include scalar fields using the standard wire encodings: + * Supported protobuf field types include: *
      *
    • VARINT: {@code int32}, {@code int64}, {@code uint32}, {@code uint64}, {@code bool}
    • *
    • ZIGZAG VARINT (encoding=2): {@code sint32}, {@code sint64}
    • *
    • FIXED32 (encoding=1): {@code fixed32}, {@code sfixed32}, {@code float}
    • *
    • FIXED64 (encoding=1): {@code fixed64}, {@code sfixed64}, {@code double}
    • - *
    • LENGTH_DELIMITED: {@code string}, {@code bytes}
    • + *
    • LENGTH_DELIMITED: {@code string}, {@code bytes}, nested {@code message}
    • + *
    • Nested messages and repeated fields
    • *
    - * - * Nested messages, repeated fields, map fields, and oneof fields are out of scope for this API. */ public class Protobuf { static { @@ -52,302 +54,18 @@ public class Protobuf { public static final int ENC_FIXED = 1; public static final int ENC_ZIGZAG = 2; - /** - * Decode a protobuf message-per-row binary column into a STRUCT column. - * - * This method supports schema projection: only the fields specified in - * {@code decodedFieldIndices} will be decoded. Other fields in the output - * struct will contain all null values. - * - * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. - * @param totalNumFields Total number of fields in the output struct (including null columns). - * @param decodedFieldIndices Indices into the output struct for fields that should be decoded. - * These must be sorted in ascending order. - * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices). - * @param allTypeIds cudf native type ids for ALL fields in the output struct (size = totalNumFields). - * @param encodings Encoding info for decoded fields (parallel to decodedFieldIndices): - * 0=default (varint), 1=fixed, 2=zigzag. - * @return a cudf STRUCT column with totalNumFields children. Decoded fields contain parsed data, - * other fields contain all nulls. - */ - public static ColumnVector decodeToStruct(ColumnView binaryInput, - int totalNumFields, - int[] decodedFieldIndices, - int[] fieldNumbers, - int[] allTypeIds, - int[] encodings) { - return decodeToStruct(binaryInput, totalNumFields, decodedFieldIndices, fieldNumbers, - allTypeIds, encodings, new boolean[decodedFieldIndices.length], true); - } - - /** - * Decode a protobuf message-per-row binary column into a STRUCT column. - * - * This method supports schema projection: only the fields specified in - * {@code decodedFieldIndices} will be decoded. Other fields in the output - * struct will contain all null values. 
- * - * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. - * @param totalNumFields Total number of fields in the output struct (including null columns). - * @param decodedFieldIndices Indices into the output struct for fields that should be decoded. - * These must be sorted in ascending order. - * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices). - * @param allTypeIds cudf native type ids for ALL fields in the output struct (size = totalNumFields). - * @param encodings Encoding info for decoded fields (parallel to decodedFieldIndices): - * 0=default (varint), 1=fixed, 2=zigzag. - * @param failOnErrors if true, throw an exception on malformed protobuf messages. - * If false, return nulls for fields that cannot be parsed. - * Note: error checking is performed after all fields are processed, - * not between fields, to avoid synchronization overhead. - * @return a cudf STRUCT column with totalNumFields children. Decoded fields contain parsed data, - * other fields contain all nulls. - */ - public static ColumnVector decodeToStruct(ColumnView binaryInput, - int totalNumFields, - int[] decodedFieldIndices, - int[] fieldNumbers, - int[] allTypeIds, - int[] encodings, - boolean failOnErrors) { - return decodeToStruct(binaryInput, totalNumFields, decodedFieldIndices, fieldNumbers, - allTypeIds, encodings, new boolean[decodedFieldIndices.length], failOnErrors); - } - - /** - * Decode a protobuf message-per-row binary column into a STRUCT column. - * - * This method supports schema projection: only the fields specified in - * {@code decodedFieldIndices} will be decoded. Other fields in the output - * struct will contain all null values. - * - * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. - * @param totalNumFields Total number of fields in the output struct (including null columns). 
- * @param decodedFieldIndices Indices into the output struct for fields that should be decoded. - * These must be sorted in ascending order. - * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices). - * @param allTypeIds cudf native type ids for ALL fields in the output struct (size = totalNumFields). - * @param encodings Encoding info for decoded fields (parallel to decodedFieldIndices): - * 0=default (varint), 1=fixed, 2=zigzag. - * @param isRequired Whether each decoded field is required (parallel to decodedFieldIndices). - * If a required field is missing and failOnErrors is true, an exception is thrown. - * @param failOnErrors if true, throw an exception on malformed protobuf messages or missing required fields. - * If false, return nulls for fields that cannot be parsed or are missing. - * Note: error checking is performed after all fields are processed, - * not between fields, to avoid synchronization overhead. - * @return a cudf STRUCT column with totalNumFields children. Decoded fields contain parsed data, - * other fields contain all nulls. - */ - public static ColumnVector decodeToStruct(ColumnView binaryInput, - int totalNumFields, - int[] decodedFieldIndices, - int[] fieldNumbers, - int[] allTypeIds, - int[] encodings, - boolean[] isRequired, - boolean failOnErrors) { - int numFields = decodedFieldIndices.length; - return decodeToStruct(binaryInput, totalNumFields, decodedFieldIndices, fieldNumbers, - allTypeIds, encodings, isRequired, - new boolean[numFields], // hasDefaultValue - all false - new long[numFields], // defaultInts - new double[numFields], // defaultFloats - new boolean[numFields], // defaultBools - new byte[numFields][], // defaultStrings - all null - failOnErrors); - } - - /** - * Decode a protobuf message-per-row binary column into a STRUCT column with default values support. 
- * - * This method supports schema projection: only the fields specified in - * {@code decodedFieldIndices} will be decoded. Other fields in the output - * struct will contain all null values. - * - * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. - * @param totalNumFields Total number of fields in the output struct (including null columns). - * @param decodedFieldIndices Indices into the output struct for fields that should be decoded. - * These must be sorted in ascending order. - * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices). - * @param allTypeIds cudf native type ids for ALL fields in the output struct (size = totalNumFields). - * @param encodings Encoding info for decoded fields (parallel to decodedFieldIndices): - * 0=default (varint), 1=fixed, 2=zigzag. - * @param isRequired Whether each decoded field is required (parallel to decodedFieldIndices). - * If a required field is missing and failOnErrors is true, an exception is thrown. - * @param hasDefaultValue Whether each decoded field has a default value (parallel to decodedFieldIndices). - * @param defaultInts Default values for int/long/enum fields (parallel to decodedFieldIndices). - * @param defaultFloats Default values for float/double fields (parallel to decodedFieldIndices). - * @param defaultBools Default values for bool fields (parallel to decodedFieldIndices). - * @param defaultStrings Default values for string/bytes fields as UTF-8 bytes (parallel to decodedFieldIndices). - * @param failOnErrors if true, throw an exception on malformed protobuf messages or missing required fields. - * If false, return nulls for fields that cannot be parsed or are missing. - * Note: error checking is performed after all fields are processed, - * not between fields, to avoid synchronization overhead. - * @return a cudf STRUCT column with totalNumFields children. 
Decoded fields contain parsed data, - * other fields contain all nulls. - */ - public static ColumnVector decodeToStruct(ColumnView binaryInput, - int totalNumFields, - int[] decodedFieldIndices, - int[] fieldNumbers, - int[] allTypeIds, - int[] encodings, - boolean[] isRequired, - boolean[] hasDefaultValue, - long[] defaultInts, - double[] defaultFloats, - boolean[] defaultBools, - byte[][] defaultStrings, - boolean failOnErrors) { - return decodeToStruct(binaryInput, totalNumFields, decodedFieldIndices, fieldNumbers, - allTypeIds, encodings, isRequired, hasDefaultValue, - defaultInts, defaultFloats, defaultBools, defaultStrings, - new int[decodedFieldIndices.length][], failOnErrors); - } - - /** - * Decode a protobuf message-per-row binary column into a STRUCT column with default values - * and enum validation support. - * - * This method supports schema projection: only the fields specified in - * {@code decodedFieldIndices} will be decoded. Other fields in the output - * struct will contain all null values. - * - * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. - * @param totalNumFields Total number of fields in the output struct (including null columns). - * @param decodedFieldIndices Indices into the output struct for fields that should be decoded. - * These must be sorted in ascending order. - * @param fieldNumbers Protobuf field numbers for decoded fields (parallel to decodedFieldIndices). - * @param allTypeIds cudf native type ids for ALL fields in the output struct (size = totalNumFields). - * @param encodings Encoding info for decoded fields (parallel to decodedFieldIndices): - * 0=default (varint), 1=fixed, 2=zigzag. - * @param isRequired Whether each decoded field is required (parallel to decodedFieldIndices). - * If a required field is missing and failOnErrors is true, an exception is thrown. - * @param hasDefaultValue Whether each decoded field has a default value (parallel to decodedFieldIndices). 
- * @param defaultInts Default values for int/long/enum fields (parallel to decodedFieldIndices). - * @param defaultFloats Default values for float/double fields (parallel to decodedFieldIndices). - * @param defaultBools Default values for bool fields (parallel to decodedFieldIndices). - * @param defaultStrings Default values for string/bytes fields as UTF-8 bytes (parallel to decodedFieldIndices). - * @param enumValidValues Valid enum values for each field (null if not an enum). Unknown enum - * values will be set to null to match Spark CPU PERMISSIVE mode behavior. - * @param failOnErrors if true, throw an exception on malformed protobuf messages or missing required fields. - * If false, return nulls for fields that cannot be parsed or are missing. - * Note: error checking is performed after all fields are processed, - * not between fields, to avoid synchronization overhead. - * @return a cudf STRUCT column with totalNumFields children. Decoded fields contain parsed data, - * other fields contain all nulls. 
- */ - public static ColumnVector decodeToStruct(ColumnView binaryInput, - int totalNumFields, - int[] decodedFieldIndices, - int[] fieldNumbers, - int[] allTypeIds, - int[] encodings, - boolean[] isRequired, - boolean[] hasDefaultValue, - long[] defaultInts, - double[] defaultFloats, - boolean[] defaultBools, - byte[][] defaultStrings, - int[][] enumValidValues, - boolean failOnErrors) { - // Parameter validation - if (decodedFieldIndices == null || fieldNumbers == null || - allTypeIds == null || encodings == null || isRequired == null || - hasDefaultValue == null || defaultInts == null || defaultFloats == null || - defaultBools == null || defaultStrings == null || enumValidValues == null) { - throw new IllegalArgumentException("Arrays must be non-null"); - } - if (totalNumFields < 0) { - throw new IllegalArgumentException("totalNumFields must be non-negative"); - } - if (allTypeIds.length != totalNumFields) { - throw new IllegalArgumentException( - "allTypeIds length (" + allTypeIds.length + ") must equal totalNumFields (" + - totalNumFields + ")"); - } - int numDecodedFields = decodedFieldIndices.length; - if (fieldNumbers.length != numDecodedFields || - encodings.length != numDecodedFields || - isRequired.length != numDecodedFields || - hasDefaultValue.length != numDecodedFields || - defaultInts.length != numDecodedFields || - defaultFloats.length != numDecodedFields || - defaultBools.length != numDecodedFields || - defaultStrings.length != numDecodedFields || - enumValidValues.length != numDecodedFields) { - throw new IllegalArgumentException( - "All decoded field arrays must have the same length as decodedFieldIndices"); - } - - // Validate decoded field indices are in bounds and sorted - int prevIdx = -1; - for (int i = 0; i < decodedFieldIndices.length; i++) { - int idx = decodedFieldIndices[i]; - if (idx < 0 || idx >= totalNumFields) { - throw new IllegalArgumentException( - "Invalid decoded field index at position " + i + ": " + idx + - " (must be in 
range [0, " + totalNumFields + "))"); - } - if (idx <= prevIdx) { - throw new IllegalArgumentException( - "decodedFieldIndices must be sorted in ascending order without duplicates"); - } - prevIdx = idx; - } - - // Validate field numbers are positive - for (int i = 0; i < fieldNumbers.length; i++) { - if (fieldNumbers[i] <= 0) { - throw new IllegalArgumentException( - "Invalid field number at index " + i + ": " + fieldNumbers[i] + - " (field numbers must be positive)"); - } - } - - // Validate encoding values - for (int i = 0; i < encodings.length; i++) { - int enc = encodings[i]; - if (enc < ENC_DEFAULT || enc > ENC_ZIGZAG) { - throw new IllegalArgumentException( - "Invalid encoding value at index " + i + ": " + enc + - " (expected " + ENC_DEFAULT + ", " + ENC_FIXED + ", or " + ENC_ZIGZAG + ")"); - } - } - - long handle = decodeToStruct(binaryInput.getNativeView(), totalNumFields, - decodedFieldIndices, fieldNumbers, allTypeIds, - encodings, isRequired, hasDefaultValue, - defaultInts, defaultFloats, defaultBools, - defaultStrings, enumValidValues, failOnErrors); - return new ColumnVector(handle); - } - - private static native long decodeToStruct(long binaryInputView, - int totalNumFields, - int[] decodedFieldIndices, - int[] fieldNumbers, - int[] allTypeIds, - int[] encodings, - boolean[] isRequired, - boolean[] hasDefaultValue, - long[] defaultInts, - double[] defaultFloats, - boolean[] defaultBools, - byte[][] defaultStrings, - int[][] enumValidValues, - boolean failOnErrors); - - // Wire type constants for nested schema + // Wire type constants public static final int WT_VARINT = 0; public static final int WT_64BIT = 1; public static final int WT_LEN = 2; public static final int WT_32BIT = 5; /** - * Decode protobuf messages with support for nested messages and repeated fields. + * Decode protobuf messages into a STRUCT column using a flattened schema representation. 
* - * This method uses a flattened schema representation where nested fields have parent indices - * pointing to their containing message field. + * The schema is represented as parallel arrays where nested fields have parent indices + * pointing to their containing message field. For pure scalar schemas, all fields are + * top-level (parentIndices == -1, depthLevels == 0, isRepeated == false). * * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. * @param fieldNumbers Protobuf field numbers for all fields in the flattened schema. @@ -367,22 +85,22 @@ private static native long decodeToStruct(long binaryInputView, * @param failOnErrors if true, throw an exception on malformed protobuf messages. * @return a cudf STRUCT column with nested structure. */ - public static ColumnVector decodeNestedToStruct(ColumnView binaryInput, - int[] fieldNumbers, - int[] parentIndices, - int[] depthLevels, - int[] wireTypes, - int[] outputTypeIds, - int[] encodings, - boolean[] isRepeated, - boolean[] isRequired, - boolean[] hasDefaultValue, - long[] defaultInts, - double[] defaultFloats, - boolean[] defaultBools, - byte[][] defaultStrings, - int[][] enumValidValues, - boolean failOnErrors) { + public static ColumnVector decodeToStruct(ColumnView binaryInput, + int[] fieldNumbers, + int[] parentIndices, + int[] depthLevels, + int[] wireTypes, + int[] outputTypeIds, + int[] encodings, + boolean[] isRepeated, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + int[][] enumValidValues, + boolean failOnErrors) { // Parameter validation if (fieldNumbers == null || parentIndices == null || depthLevels == null || wireTypes == null || outputTypeIds == null || encodings == null || @@ -428,29 +146,29 @@ public static ColumnVector decodeNestedToStruct(ColumnView binaryInput, } } - long handle = decodeNestedToStruct(binaryInput.getNativeView(), - 
fieldNumbers, parentIndices, depthLevels, - wireTypes, outputTypeIds, encodings, - isRepeated, isRequired, hasDefaultValue, - defaultInts, defaultFloats, defaultBools, - defaultStrings, enumValidValues, failOnErrors); + long handle = decodeToStruct(binaryInput.getNativeView(), + fieldNumbers, parentIndices, depthLevels, + wireTypes, outputTypeIds, encodings, + isRepeated, isRequired, hasDefaultValue, + defaultInts, defaultFloats, defaultBools, + defaultStrings, enumValidValues, failOnErrors); return new ColumnVector(handle); } - private static native long decodeNestedToStruct(long binaryInputView, - int[] fieldNumbers, - int[] parentIndices, - int[] depthLevels, - int[] wireTypes, - int[] outputTypeIds, - int[] encodings, - boolean[] isRepeated, - boolean[] isRequired, - boolean[] hasDefaultValue, - long[] defaultInts, - double[] defaultFloats, - boolean[] defaultBools, - byte[][] defaultStrings, - int[][] enumValidValues, - boolean failOnErrors); + private static native long decodeToStruct(long binaryInputView, + int[] fieldNumbers, + int[] parentIndices, + int[] depthLevels, + int[] wireTypes, + int[] outputTypeIds, + int[] encodings, + boolean[] isRepeated, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + int[][] enumValidValues, + boolean failOnErrors); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 044812f128..20bbb0cd47 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -146,12 +146,63 @@ private static Byte[] concat(Byte[]... 
parts) { } // ============================================================================ - // Helper methods for calling the new API + // Helper methods for calling the unified API // ============================================================================ /** - * Helper method that wraps the new API for tests that decode all fields. - * This simulates the old API behavior where all fields are decoded. + * Derive protobuf wire type from cudf type ID and encoding. + * This is only used by test helpers for scalar fields. + */ + private static int getWireType(int cudfTypeId, int encoding) { + if (cudfTypeId == DType.FLOAT32.getTypeId().getNativeId()) return Protobuf.WT_32BIT; + if (cudfTypeId == DType.FLOAT64.getTypeId().getNativeId()) return Protobuf.WT_64BIT; + if (cudfTypeId == DType.STRING.getTypeId().getNativeId()) return Protobuf.WT_LEN; + if (cudfTypeId == DType.LIST.getTypeId().getNativeId()) return Protobuf.WT_LEN; // bytes + if (cudfTypeId == DType.STRUCT.getTypeId().getNativeId()) return Protobuf.WT_LEN; + // INT32, INT64, BOOL8 - varint or fixed + if (encoding == Protobuf.ENC_FIXED) { + if (cudfTypeId == DType.INT64.getTypeId().getNativeId()) return Protobuf.WT_64BIT; + return Protobuf.WT_32BIT; + } + return Protobuf.WT_VARINT; + } + + /** + * Helper to decode all scalar fields using the unified API. + * Builds a flat schema (parentIndices=-1, depth=0, isRepeated=false for all fields). 
+ */ + private static ColumnVector decodeScalarFields(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] encodings, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + int[][] enumValidValues, + boolean failOnErrors) { + int numFields = fieldNumbers.length; + int[] parentIndices = new int[numFields]; + int[] depthLevels = new int[numFields]; + int[] wireTypes = new int[numFields]; + boolean[] isRepeated = new boolean[numFields]; + + java.util.Arrays.fill(parentIndices, -1); + // depthLevels already initialized to 0 + // isRepeated already initialized to false + for (int i = 0; i < numFields; i++) { + wireTypes[i] = getWireType(typeIds[i], encodings[i]); + } + + return Protobuf.decodeToStruct(binaryInput, fieldNumbers, parentIndices, depthLevels, + wireTypes, typeIds, encodings, isRepeated, isRequired, hasDefaultValue, + defaultInts, defaultFloats, defaultBools, defaultStrings, enumValidValues, failOnErrors); + } + + /** + * Helper method that wraps the unified API for tests that decode all scalar fields. */ private static ColumnVector decodeAllFields(ColumnView binaryInput, int[] fieldNumbers, @@ -161,8 +212,7 @@ private static ColumnVector decodeAllFields(ColumnView binaryInput, } /** - * Helper method that wraps the new API for tests that decode all fields. - * This simulates the old API behavior where all fields are decoded. + * Helper method that wraps the unified API for tests that decode all scalar fields. 
*/ private static ColumnVector decodeAllFields(ColumnView binaryInput, int[] fieldNumbers, @@ -170,14 +220,10 @@ private static ColumnVector decodeAllFields(ColumnView binaryInput, int[] encodings, boolean failOnErrors) { int numFields = fieldNumbers.length; - // When decoding all fields, decodedFieldIndices is [0, 1, 2, ..., n-1] - int[] decodedFieldIndices = new int[numFields]; - boolean[] isRequired = new boolean[numFields]; // all false by default - for (int i = 0; i < numFields; i++) { - decodedFieldIndices[i] = i; - } - return Protobuf.decodeToStruct(binaryInput, numFields, decodedFieldIndices, - fieldNumbers, typeIds, encodings, isRequired, failOnErrors); + return decodeScalarFields(binaryInput, fieldNumbers, typeIds, encodings, + new boolean[numFields], new boolean[numFields], new long[numFields], + new double[numFields], new boolean[numFields], new byte[numFields][], + new int[numFields][], failOnErrors); } /** @@ -190,12 +236,10 @@ private static ColumnVector decodeAllFieldsWithRequired(ColumnView binaryInput, boolean[] isRequired, boolean failOnErrors) { int numFields = fieldNumbers.length; - int[] decodedFieldIndices = new int[numFields]; - for (int i = 0; i < numFields; i++) { - decodedFieldIndices[i] = i; - } - return Protobuf.decodeToStruct(binaryInput, numFields, decodedFieldIndices, - fieldNumbers, typeIds, encodings, isRequired, failOnErrors); + return decodeScalarFields(binaryInput, fieldNumbers, typeIds, encodings, + isRequired, new boolean[numFields], new long[numFields], + new double[numFields], new boolean[numFields], new byte[numFields][], + new int[numFields][], failOnErrors); } // ============================================================================ @@ -320,22 +364,18 @@ void testSchemaProjection() { box(tag(3, WT_VARINT)), box(encodeVarint(42))); try (Table input = new Table.TestBuilder().column(new Byte[][]{row0}).build(); - // Expected: f1=100, f2=null (not decoded), f3=42 + // Expected: f1=100, f3=42 (schema projection: only 
decode these two) ColumnVector expectedF1 = ColumnVector.fromBoxedLongs(100L); - ColumnVector expectedF2 = ColumnVector.fromStrings((String)null); ColumnVector expectedF3 = ColumnVector.fromBoxedInts(42); - ColumnVector expectedStruct = ColumnVector.makeStruct(expectedF1, expectedF2, expectedF3); - // Decode only fields at indices 0 and 2 (skip index 1) - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedF1, expectedF3); + // Decode only f1 (field_number=1) and f3 (field_number=3), skip f2 + // With the unified API, we only include the fields we want in the schema + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), - 3, // total fields - new int[]{0, 2}, // decode only indices 0 and 2 - new int[]{1, 3}, // field numbers for decoded fields + new int[]{1, 3}, // field numbers for f1 and f3 new int[]{DType.INT64.getTypeId().getNativeId(), - DType.STRING.getTypeId().getNativeId(), - DType.INT32.getTypeId().getNativeId()}, // types for ALL fields - new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, // encodings for decoded fields - true)) { + DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT})) { AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); } } @@ -348,19 +388,14 @@ void testSchemaProjectionDecodeNone() { box(tag(2, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes())); try (Table input = new Table.TestBuilder().column(new Byte[][]{row0}).build(); - ColumnVector expectedF1 = ColumnVector.fromBoxedLongs((Long)null); - ColumnVector expectedF2 = ColumnVector.fromStrings((String)null); - ColumnVector expectedStruct = ColumnVector.makeStruct(expectedF1, expectedF2); - ColumnVector actualStruct = Protobuf.decodeToStruct( + // With no fields in the schema, the GPU returns an empty struct + ColumnVector actualStruct = decodeAllFields( input.getColumn(0), - 2, // total fields - new int[]{}, // decode no fields new 
int[]{}, // no field numbers - new int[]{DType.INT64.getTypeId().getNativeId(), - DType.STRING.getTypeId().getNativeId()}, // types for ALL fields - new int[]{}, // no encodings - true)) { - AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + new int[]{}, // no types + new int[]{})) { // no encodings + assertNotNull(actualStruct); + assertEquals(DType.STRUCT, actualStruct.getType()); } } @@ -1259,14 +1294,9 @@ private static ColumnVector decodeAllFieldsWithDefaults(ColumnView binaryInput, byte[][] defaultStrings, boolean failOnErrors) { int numFields = fieldNumbers.length; - int[] decodedFieldIndices = new int[numFields]; - for (int i = 0; i < numFields; i++) { - decodedFieldIndices[i] = i; - } - return Protobuf.decodeToStruct(binaryInput, numFields, decodedFieldIndices, - fieldNumbers, typeIds, encodings, isRequired, - hasDefaultValue, defaultInts, defaultFloats, - defaultBools, defaultStrings, failOnErrors); + return decodeScalarFields(binaryInput, fieldNumbers, typeIds, encodings, + isRequired, hasDefaultValue, defaultInts, defaultFloats, defaultBools, + defaultStrings, new int[numFields][], failOnErrors); } @Test @@ -1632,7 +1662,7 @@ void testUnpackedRepeatedInt32() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { // Use the new nested API for repeated fields // Field: ids (field_number=1, parent=-1, depth=0, wire_type=VARINT, type=INT32, repeated=true) - try (ColumnVector result = Protobuf.decodeNestedToStruct( + try (ColumnVector result = Protobuf.decodeToStruct( input.getColumn(0), new int[]{1}, // fieldNumbers new int[]{-1}, // parentIndices (-1 = top level) @@ -1672,7 +1702,7 @@ void testNestedMessage() { // Flattened schema: // [0] inner: STRUCT, field_number=1, parent=-1, depth=0 // [1] inner.x: INT32, field_number=1, parent=0, depth=1 - try (ColumnVector result = Protobuf.decodeNestedToStruct( + try (ColumnVector result = Protobuf.decodeToStruct( input.getColumn(0), new int[]{1, 1}, // fieldNumbers 
new int[]{-1, 0}, // parentIndices @@ -1893,20 +1923,10 @@ private static ColumnVector decodeAllFieldsWithEnums(ColumnView binaryInput, int[][] enumValidValues, boolean failOnErrors) { int numFields = fieldNumbers.length; - int[] decodedFieldIndices = new int[numFields]; - boolean[] isRequired = new boolean[numFields]; - boolean[] hasDefaultValue = new boolean[numFields]; - long[] defaultInts = new long[numFields]; - double[] defaultFloats = new double[numFields]; - boolean[] defaultBools = new boolean[numFields]; - byte[][] defaultStrings = new byte[numFields][]; - for (int i = 0; i < numFields; i++) { - decodedFieldIndices[i] = i; - } - return Protobuf.decodeToStruct(binaryInput, numFields, decodedFieldIndices, - fieldNumbers, typeIds, encodings, isRequired, - hasDefaultValue, defaultInts, defaultFloats, - defaultBools, defaultStrings, enumValidValues, failOnErrors); + return decodeScalarFields(binaryInput, fieldNumbers, typeIds, encodings, + new boolean[numFields], new boolean[numFields], new long[numFields], + new double[numFields], new boolean[numFields], new byte[numFields][], + enumValidValues, failOnErrors); } @Test From aa3c85232e376cf99024799b66b49957eb14ab02 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 11 Feb 2026 19:35:54 +0800 Subject: [PATCH 021/107] Fix nested type support Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 32 + src/main/cpp/src/protobuf.cu | 1938 ++++++++++++----- src/main/cpp/src/protobuf.hpp | 4 + .../com/nvidia/spark/rapids/jni/Protobuf.java | 47 +- .../nvidia/spark/rapids/jni/ProtobufTest.java | 219 +- 5 files changed, 1648 insertions(+), 592 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index d40f4c0512..8fb6ac81a6 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -41,6 +41,7 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, jbooleanArray default_bools, jobjectArray default_strings, jobjectArray 
enum_valid_values, + jobjectArray enum_names, jboolean fail_on_errors) { JNI_NULL_CHECK(env, binary_input_view, "binary_input_view is null", 0); @@ -58,6 +59,7 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, JNI_NULL_CHECK(env, default_bools, "default_bools is null", 0); JNI_NULL_CHECK(env, default_strings, "default_strings is null", 0); JNI_NULL_CHECK(env, enum_valid_values, "enum_valid_values is null", 0); + JNI_NULL_CHECK(env, enum_names, "enum_names is null", 0); JNI_TRY { @@ -163,6 +165,35 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, } } + // Convert enum names (byte[][][]). For each field: + // - null => not an enum-as-string field + // - byte[][] where each byte[] is UTF-8 enum name, ordered with enum_values[field] + std::vector>> enum_name_values; + enum_name_values.reserve(num_fields); + for (int i = 0; i < num_fields; ++i) { + jobjectArray names_arr = static_cast(env->GetObjectArrayElement(enum_names, i)); + if (names_arr == nullptr) { + enum_name_values.emplace_back(); + } else { + jsize num_names = env->GetArrayLength(names_arr); + std::vector> names_for_field; + names_for_field.reserve(num_names); + for (jsize j = 0; j < num_names; ++j) { + jbyteArray name_bytes = static_cast(env->GetObjectArrayElement(names_arr, j)); + if (name_bytes == nullptr) { + names_for_field.emplace_back(); + } else { + jsize len = env->GetArrayLength(name_bytes); + jbyte* bytes = env->GetByteArrayElements(name_bytes, nullptr); + names_for_field.emplace_back(reinterpret_cast(bytes), + reinterpret_cast(bytes) + len); + env->ReleaseByteArrayElements(name_bytes, bytes, JNI_ABORT); + } + } + enum_name_values.push_back(std::move(names_for_field)); + } + } + auto result = spark_rapids_jni::decode_protobuf_to_struct( *input, schema, @@ -172,6 +203,7 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, default_bool_values, default_string_values, enum_values, + enum_name_values, fail_on_errors); return 
cudf::jni::release_as_jlong(result); diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index dfd59a6ff9..841b78268c 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -56,6 +56,8 @@ constexpr int WT_32BIT = 5; namespace { +constexpr int MAX_NESTED_STRUCT_DECODE_DEPTH = 10; + /** * Structure to record field location within a message. * offset < 0 means field was not found. @@ -1292,6 +1294,45 @@ __global__ void extract_repeated_in_nested_varint_kernel( out[idx] = static_cast(val); } +template +__global__ void extract_repeated_in_nested_fixed_kernel( + uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + repeated_occurrence const* occurrences, + int total_count, + OutT* out, + int* error_flag) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_count) return; + + auto const& occ = occurrences[idx]; + auto const& parent_loc = parent_locs[occ.row_idx]; + + cudf::size_type row_off = row_offsets[occ.row_idx] - base_offset; + uint8_t const* data_ptr = message_data + row_off + parent_loc.offset + occ.offset; + + if constexpr (WT == WT_32BIT) { + if (occ.length < 4) { + atomicExch(error_flag, 1); + out[idx] = OutT{}; + return; + } + uint32_t raw = load_le(data_ptr); + memcpy(&out[idx], &raw, sizeof(OutT)); + } else { + if (occ.length < 8) { + atomicExch(error_flag, 1); + out[idx] = OutT{}; + return; + } + uint64_t raw = load_le(data_ptr); + memcpy(&out[idx], &raw, sizeof(OutT)); + } +} + /** * Extract string values from repeated field occurrences within nested messages. 
*/ @@ -1558,6 +1599,83 @@ inline std::unique_ptr build_repeated_msg_child_string_column( return cudf::make_strings_column(total_count, std::move(str_offsets_col), d_chars.release(), null_count, std::move(mask)); } +inline std::unique_ptr build_repeated_msg_child_bytes_column( + uint8_t const* message_data, + rmm::device_uvector const& d_msg_row_offsets, + rmm::device_uvector const& d_msg_locs, + rmm::device_uvector const& d_child_locs, + int child_idx, + int num_child_fields, + int total_count, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (total_count == 0) { + auto empty_offsets = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, 1, + rmm::device_buffer(sizeof(int32_t), stream, mr), rmm::device_buffer{}, 0); + int32_t zero = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync( + empty_offsets->mutable_view().data(), &zero, sizeof(int32_t), + cudaMemcpyHostToDevice, stream.value())); + auto empty_bytes = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); + return cudf::make_lists_column(0, std::move(empty_offsets), std::move(empty_bytes), + 0, rmm::device_buffer{}, stream, mr); + } + + auto const threads = 256; + auto const blocks = (total_count + threads - 1) / threads; + + rmm::device_uvector d_lengths(total_count, stream, mr); + compute_string_lengths_kernel<<>>( + d_child_locs.data(), child_idx, num_child_fields, d_lengths.data(), total_count); + + rmm::device_uvector d_offs(total_count + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), + d_lengths.begin(), d_lengths.end(), + d_offs.begin(), 0); + + int32_t total_bytes = 0; + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, d_offs.data() + total_count - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, d_lengths.data() + total_count - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + 
stream.synchronize(); + total_bytes += last_len; + CUDF_CUDA_TRY(cudaMemcpyAsync(d_offs.data() + total_count, &total_bytes, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector d_bytes(total_bytes, stream, mr); + rmm::device_uvector d_valid(total_count, stream, mr); + + if (total_bytes > 0) { + extract_repeated_msg_child_strings_kernel<<>>( + message_data, d_msg_row_offsets.data(), d_msg_locs.data(), + d_child_locs.data(), child_idx, num_child_fields, + d_offs.data(), d_bytes.data(), d_valid.data(), total_count); + } else { + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + d_valid.begin(), + [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { + return child_locs[idx * ncf + ci].offset >= 0; + }); + } + + auto [mask, null_count] = make_null_mask_from_valid(d_valid, stream, mr); + auto offs_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, total_count + 1, d_offs.release(), rmm::device_buffer{}, 0); + auto bytes_child = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, total_bytes, + rmm::device_buffer(d_bytes.data(), total_bytes, stream, mr), rmm::device_buffer{}, 0); + return cudf::make_lists_column(total_count, std::move(offs_col), std::move(bytes_child), + null_count, std::move(mask), stream, mr); +} + /** * Kernel to compute nested struct locations from child field locations. * Replaces host-side loop that was copying data D->H, processing, then H->D. @@ -1610,6 +1728,37 @@ __global__ void compute_grandchild_parent_locations_kernel( } } +/** + * Compute virtual parent row offsets and locations for repeated message occurrences + * inside nested messages. Each occurrence becomes a virtual "row" so that + * build_nested_struct_column can recursively process the children. 
+ */ +__global__ void compute_virtual_parents_for_nested_repeated_kernel( + repeated_occurrence const* occurrences, + cudf::size_type const* row_list_offsets, // original binary input list offsets + field_location const* parent_locations, // parent nested message locations + cudf::size_type* virtual_row_offsets, // output: [total_count] + field_location* virtual_parent_locs, // output: [total_count] + int total_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_count) return; + + auto const& occ = occurrences[idx]; + auto const& ploc = parent_locations[occ.row_idx]; + + virtual_row_offsets[idx] = row_list_offsets[occ.row_idx]; + + // Keep zero-length embedded messages as "present but empty". + // Protobuf allows an embedded message with length=0, which maps to a non-null + // struct with all-null children (not a null struct). + if (ploc.offset >= 0) { + virtual_parent_locs[idx] = {ploc.offset + occ.offset, occ.length}; + } else { + virtual_parent_locs[idx] = {-1, 0}; + } +} + /** * Kernel to compute message locations and row offsets from repeated occurrences. * Replaces host-side loop that processed occurrences. @@ -1955,10 +2104,10 @@ std::unique_ptr make_null_column(cudf::data_type dtype, rmm::device_buffer{}, 0); - // Empty child column - use INT8 as default element type + // Empty child column - use UINT8 for BinaryType consistency // This works because the list has 0 elements, so the child type doesn't matter for nulls auto child_col = std::make_unique( - cudf::data_type{cudf::type_id::INT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); + cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); // All null mask auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); @@ -2164,6 +2313,88 @@ __global__ void validate_enum_values_kernel( } } +/** + * Compute output UTF-8 length for enum-as-string rows. 
+ * Invalid/missing values produce length 0 (null row/field semantics handled by valid[] and + * row_has_invalid_enum). + */ +__global__ void compute_enum_string_lengths_kernel( + int32_t const* values, // [num_rows] enum numeric values + bool const* valid, // [num_rows] field validity + int32_t const* valid_enum_values, // sorted enum numeric values + int32_t const* enum_name_offsets, // [num_valid_values + 1] + int num_valid_values, + int32_t* lengths, // [num_rows] + int num_rows) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + if (!valid[row]) { + lengths[row] = 0; + return; + } + + int32_t val = values[row]; + int left = 0; + int right = num_valid_values - 1; + while (left <= right) { + int mid = left + (right - left) / 2; + int32_t mid_val = valid_enum_values[mid]; + if (mid_val == val) { + lengths[row] = enum_name_offsets[mid + 1] - enum_name_offsets[mid]; + return; + } else if (mid_val < val) { + left = mid + 1; + } else { + right = mid - 1; + } + } + + // Should not happen when validate_enum_values_kernel has already run, but keep safe. + lengths[row] = 0; +} + +/** + * Copy enum-as-string UTF-8 bytes into output chars buffer using precomputed row offsets. 
+ */ +__global__ void copy_enum_string_chars_kernel( + int32_t const* values, // [num_rows] enum numeric values + bool const* valid, // [num_rows] field validity + int32_t const* valid_enum_values, // sorted enum numeric values + int32_t const* enum_name_offsets, // [num_valid_values + 1] + uint8_t const* enum_name_chars, // concatenated enum UTF-8 names + int num_valid_values, + int32_t const* output_offsets, // [num_rows + 1] + char* out_chars, // [total_chars] + int num_rows) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + if (!valid[row]) return; + + int32_t val = values[row]; + int left = 0; + int right = num_valid_values - 1; + while (left <= right) { + int mid = left + (right - left) / 2; + int32_t mid_val = valid_enum_values[mid]; + if (mid_val == val) { + int32_t src_begin = enum_name_offsets[mid]; + int32_t src_end = enum_name_offsets[mid + 1]; + int32_t dst_begin = output_offsets[row]; + for (int32_t i = 0; i < (src_end - src_begin); ++i) { + out_chars[dst_begin + i] = static_cast(enum_name_chars[src_begin + i]); + } + return; + } else if (mid_val < val) { + left = mid + 1; + } else { + right = mid - 1; + } + } +} + namespace spark_rapids_jni { @@ -2422,6 +2653,30 @@ std::unique_ptr build_repeated_string_column( return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); } +// Forward declaration -- build_nested_struct_column is defined after build_repeated_struct_column +// but the latter's STRUCT-child case needs to call it. 
+std::unique_ptr build_nested_struct_column( + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_parent_locs, + std::vector const& child_field_indices, + std::vector const& schema, + int num_fields, + std::vector const& schema_output_types, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_error, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int depth); + /** * Build a repeated struct column (LIST of STRUCT). * This handles repeated message fields like: repeated Item items = 2; @@ -2441,6 +2696,12 @@ std::unique_ptr build_repeated_struct_column( std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, + std::vector> const& default_strings, + std::vector const& schema, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_error_top, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { @@ -2658,114 +2919,47 @@ std::unique_ptr build_repeated_struct_column( d_child_locs, ci, num_child_fields, total_count, d_error, stream, mr)); break; } + case cudf::type_id::LIST: { + // bytes (BinaryType) child inside repeated message + struct_children.push_back( + build_repeated_msg_child_bytes_column( + message_data, d_msg_row_offsets, d_msg_locs, + d_child_locs, ci, num_child_fields, total_count, d_error, stream, mr)); + break; + } case cudf::type_id::STRUCT: { - // Nested struct inside repeated message - need to extract grandchild fields + // Nested struct inside repeated message - use recursive build_nested_struct_column int num_schema_fields = 
static_cast(h_device_schema.size()); auto grandchild_indices = find_child_field_indices(h_device_schema, num_schema_fields, child_schema_idx); if (grandchild_indices.empty()) { - // No grandchildren - create empty struct column struct_children.push_back(cudf::make_structs_column( total_count, std::vector>{}, 0, rmm::device_buffer{}, stream, mr)); } else { - // Build grandchild columns - // For each occurrence, the nested struct location is in child_locs[occ * num_child_fields + ci] - // We need to scan within each nested struct for grandchild fields - - // Build grandchild field descriptors - int num_grandchildren = static_cast(grandchild_indices.size()); - std::vector h_gc_descs(num_grandchildren); - for (int gci = 0; gci < num_grandchildren; gci++) { - int gc_schema_idx = grandchild_indices[gci]; - h_gc_descs[gci].field_number = h_device_schema[gc_schema_idx].field_number; - h_gc_descs[gci].expected_wire_type = h_device_schema[gc_schema_idx].wire_type; - } - rmm::device_uvector d_gc_descs(num_grandchildren, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_gc_descs.data(), h_gc_descs.data(), - num_grandchildren * sizeof(field_descriptor), - cudaMemcpyHostToDevice, stream.value())); - - // Create nested struct locations from child_locs using GPU kernel - // This eliminates the D->H->D copy pattern (critical performance optimization) + // Compute virtual parent locations for each occurrence's nested struct child rmm::device_uvector d_nested_locs(total_count, stream, mr); - rmm::device_uvector d_nested_row_offsets(total_count, stream, mr); - compute_nested_struct_locations_kernel<<>>( - d_child_locs.data(), d_msg_locs.data(), d_msg_row_offsets.data(), - ci, num_child_fields, d_nested_locs.data(), d_nested_row_offsets.data(), total_count); - - // Scan for grandchild fields - rmm::device_uvector d_gc_locs(total_count * num_grandchildren, stream, mr); - scan_repeated_message_children_kernel<<>>( - message_data, d_nested_row_offsets.data(), d_nested_locs.data(), 
total_count, - d_gc_descs.data(), num_grandchildren, d_gc_locs.data(), d_error.data()); - - // Copy grandchild locations to host - std::vector h_gc_locs(total_count * num_grandchildren); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_gc_locs.data(), d_gc_locs.data(), - h_gc_locs.size() * sizeof(field_location), - cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - - // Extract grandchild values - std::vector> grandchild_cols; - for (int gci = 0; gci < num_grandchildren; gci++) { - int gc_schema_idx = grandchild_indices[gci]; - auto const gc_dt = schema_output_types[gc_schema_idx]; - auto const gc_enc = h_device_schema[gc_schema_idx].encoding; - bool gc_has_def = h_device_schema[gc_schema_idx].has_default_value; - - switch (gc_dt.id()) { - case cudf::type_id::INT32: { - rmm::device_uvector out(total_count, stream, mr); - rmm::device_uvector valid(total_count, stream, mr); - int64_t def_val = gc_has_def ? default_ints[gc_schema_idx] : 0; - if (gc_enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_repeated_msg_child_varint_kernel<<>>( - message_data, d_nested_row_offsets.data(), d_nested_locs.data(), - d_gc_locs.data(), gci, num_grandchildren, out.data(), valid.data(), - total_count, d_error.data(), gc_has_def, def_val); - } else { - extract_repeated_msg_child_varint_kernel<<>>( - message_data, d_nested_row_offsets.data(), d_nested_locs.data(), - d_gc_locs.data(), gci, num_grandchildren, out.data(), valid.data(), - total_count, d_error.data(), gc_has_def, def_val); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - grandchild_cols.push_back(std::make_unique( - gc_dt, total_count, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::INT64: { - rmm::device_uvector out(total_count, stream, mr); - rmm::device_uvector valid(total_count, stream, mr); - int64_t def_val = gc_has_def ? 
default_ints[gc_schema_idx] : 0; - extract_repeated_msg_child_varint_kernel<<>>( - message_data, d_nested_row_offsets.data(), d_nested_locs.data(), - d_gc_locs.data(), gci, num_grandchildren, out.data(), valid.data(), - total_count, d_error.data(), gc_has_def, def_val); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - grandchild_cols.push_back(std::make_unique( - gc_dt, total_count, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::STRING: { - grandchild_cols.push_back( - build_repeated_msg_child_string_column( - message_data, d_nested_row_offsets, d_nested_locs, - d_gc_locs, gci, num_grandchildren, total_count, d_error, stream, mr)); - break; - } - default: - // Unsupported grandchild type - create null column - grandchild_cols.push_back(make_null_column(gc_dt, total_count, stream, mr)); - break; - } + rmm::device_uvector d_nested_row_offsets(total_count, stream, mr); + { + // Convert int32_t row offsets to cudf::size_type and compute nested struct locations + rmm::device_uvector d_nested_row_offsets_i32(total_count, stream, mr); + compute_nested_struct_locations_kernel<<>>( + d_child_locs.data(), d_msg_locs.data(), d_msg_row_offsets.data(), + ci, num_child_fields, d_nested_locs.data(), d_nested_row_offsets_i32.data(), total_count); + // Add base_offset back so build_nested_struct_column can subtract it + thrust::transform(rmm::exec_policy(stream), + d_nested_row_offsets_i32.begin(), d_nested_row_offsets_i32.end(), + d_nested_row_offsets.begin(), + [base_offset] __device__(int32_t v) { + return static_cast(v) + base_offset; + }); } - - // Build the nested struct column - auto nested_struct_col = cudf::make_structs_column( - total_count, std::move(grandchild_cols), 0, rmm::device_buffer{}, stream, mr); - struct_children.push_back(std::move(nested_struct_col)); + + struct_children.push_back(build_nested_struct_column( + message_data, d_nested_row_offsets.data(), base_offset, d_nested_locs, + 
grandchild_indices, schema, num_schema_fields, schema_output_types, + default_ints, default_floats, default_bools, default_strings, + enum_valid_values, enum_names, d_row_has_invalid_enum, d_error_top, + total_count, stream, mr, 0)); } break; } @@ -2793,84 +2987,727 @@ std::unique_ptr build_repeated_struct_column( return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(struct_col), 0, rmm::device_buffer{}, stream, mr); } -} // anonymous namespace - -std::unique_ptr decode_protobuf_to_struct( - cudf::column_view const& binary_input, +/** + * Recursively build a nested STRUCT column from parent message locations. + * This supports arbitrarily deep protobuf nesting (bounded by MAX_NESTED_STRUCT_DECODE_DEPTH). + */ +std::unique_ptr build_nested_struct_column( + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_parent_locs, + std::vector const& child_field_indices, std::vector const& schema, + int num_fields, std::vector const& schema_output_types, std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, std::vector> const& default_strings, std::vector> const& enum_valid_values, - bool fail_on_errors) + std::vector>> const& enum_names, + rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_error, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int depth) { - CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, - "binary_input must be a LIST column"); - cudf::lists_column_view const in_list(binary_input); - auto const child_type = in_list.child().type().id(); - CUDF_EXPECTS(child_type == cudf::type_id::INT8 || child_type == cudf::type_id::UINT8, - "binary_input must be a LIST column"); - - auto const stream = cudf::get_default_stream(); - auto mr = cudf::get_current_device_resource_ref(); - auto num_rows = binary_input.size(); - auto num_fields = 
static_cast(schema.size()); + CUDF_EXPECTS(depth <= MAX_NESTED_STRUCT_DECODE_DEPTH, + "Nested protobuf struct depth exceeds supported decode recursion limit"); - if (num_rows == 0 || num_fields == 0) { - // Build empty struct based on top-level fields with proper nested structure + if (num_rows == 0) { std::vector> empty_children; - for (int i = 0; i < num_fields; i++) { - if (schema[i].parent_idx == -1) { - auto field_type = schema_output_types[i]; - if (schema[i].is_repeated && field_type.id() == cudf::type_id::STRUCT) { - // Repeated message field - build empty LIST with proper struct element - rmm::device_uvector offsets(1, stream, mr); - int32_t zero = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(offsets.data(), &zero, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, 1, offsets.release(), rmm::device_buffer{}, 0); - auto empty_struct = make_empty_struct_column_with_schema( - schema, schema_output_types, i, num_fields, stream, mr); - empty_children.push_back(cudf::make_lists_column(0, std::move(offsets_col), std::move(empty_struct), - 0, rmm::device_buffer{}, stream, mr)); - } else if (field_type.id() == cudf::type_id::STRUCT && !schema[i].is_repeated) { - // Non-repeated nested message field - empty_children.push_back(make_empty_struct_column_with_schema( - schema, schema_output_types, i, num_fields, stream, mr)); - } else { - empty_children.push_back(make_empty_column_safe(field_type, stream, mr)); - } + for (int child_schema_idx : child_field_indices) { + auto child_type = schema_output_types[child_schema_idx]; + if (child_type.id() == cudf::type_id::STRUCT) { + empty_children.push_back(make_empty_struct_column_with_schema( + schema, schema_output_types, child_schema_idx, num_fields, stream, mr)); + } else { + empty_children.push_back(make_empty_column_safe(child_type, stream, mr)); } } return cudf::make_structs_column(0, std::move(empty_children), 0, rmm::device_buffer{}, 
stream, mr); } - // Copy schema to device - std::vector h_device_schema(num_fields); - for (int i = 0; i < num_fields; i++) { - h_device_schema[i] = { - schema[i].field_number, - schema[i].parent_idx, - schema[i].depth, - schema[i].wire_type, - static_cast(schema[i].output_type), - schema[i].encoding, - schema[i].is_repeated, - schema[i].is_required, - schema[i].has_default_value - }; + auto const threads = 256; + auto const blocks = (num_rows + threads - 1) / threads; + int num_child_fields = static_cast(child_field_indices.size()); + + std::vector h_child_field_descs(num_child_fields); + for (int i = 0; i < num_child_fields; i++) { + int child_idx = child_field_indices[i]; + h_child_field_descs[i].field_number = schema[child_idx].field_number; + h_child_field_descs[i].expected_wire_type = schema[child_idx].wire_type; } - rmm::device_uvector d_schema(num_fields, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_schema.data(), h_device_schema.data(), - num_fields * sizeof(device_nested_field_descriptor), + rmm::device_uvector d_child_field_descs(num_child_fields, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_field_descs.data(), h_child_field_descs.data(), + num_child_fields * sizeof(field_descriptor), cudaMemcpyHostToDevice, stream.value())); - auto d_in = cudf::column_device_view::create(binary_input, stream); + rmm::device_uvector d_child_locations( + static_cast(num_rows) * num_child_fields, stream, mr); + scan_nested_message_fields_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, + d_child_field_descs.data(), num_child_fields, d_child_locations.data(), d_error.data()); - // Identify repeated and nested fields at depth 0 + std::vector> struct_children; + for (int ci = 0; ci < num_child_fields; ci++) { + int child_schema_idx = child_field_indices[ci]; + auto const dt = schema_output_types[child_schema_idx]; + auto const enc = schema[child_schema_idx].encoding; + bool has_def = schema[child_schema_idx].has_default_value; + 
bool is_repeated = schema[child_schema_idx].is_repeated; + + if (is_repeated) { + auto elem_type_id = schema[child_schema_idx].output_type; + rmm::device_uvector d_rep_info(num_rows, stream, mr); + + std::vector rep_indices = {0}; + rmm::device_uvector d_rep_indices(1, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_indices.data(), rep_indices.data(), + sizeof(int), cudaMemcpyHostToDevice, stream.value())); + + device_nested_field_descriptor rep_desc; + rep_desc.field_number = schema[child_schema_idx].field_number; + rep_desc.wire_type = schema[child_schema_idx].wire_type; + rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); + rep_desc.is_repeated = true; + rep_desc.parent_idx = -1; + rep_desc.depth = 0; + rep_desc.encoding = 0; + rep_desc.is_required = false; + rep_desc.has_default_value = false; + + std::vector h_rep_schema = {rep_desc}; + rmm::device_uvector d_rep_schema(1, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_schema.data(), h_rep_schema.data(), + sizeof(device_nested_field_descriptor), cudaMemcpyHostToDevice, stream.value())); + + count_repeated_in_nested_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, + d_rep_schema.data(), 1, d_rep_info.data(), 1, d_rep_indices.data(), d_error.data()); + + rmm::device_uvector d_rep_counts(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_info.begin(), d_rep_info.end(), + d_rep_counts.begin(), + [] __device__(repeated_field_info const& info) { return info.count; }); + int total_rep_count = thrust::reduce(rmm::exec_policy(stream), + d_rep_counts.begin(), d_rep_counts.end(), 0); + + if (total_rep_count == 0) { + rmm::device_uvector list_offsets_vec(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy(stream), list_offsets_vec.begin(), list_offsets_vec.end(), 0); + auto list_offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offsets_vec.release(), + rmm::device_buffer{}, 0); + 
std::unique_ptr child_col; + if (elem_type_id == cudf::type_id::STRUCT) { + child_col = make_empty_struct_column_with_schema( + schema, schema_output_types, child_schema_idx, num_fields, stream, mr); + } else { + child_col = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); + } + struct_children.push_back(cudf::make_lists_column( + num_rows, std::move(list_offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr)); + } else { + rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); + scan_repeated_in_nested_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, + d_rep_schema.data(), 1, d_rep_info.data(), 1, d_rep_indices.data(), + d_rep_occs.data(), d_error.data()); + + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), + d_rep_counts.begin(), d_rep_counts.end(), + list_offs.begin(), 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_rep_count, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + std::unique_ptr child_values; + if (elem_type_id == cudf::type_id::INT32) { + rmm::device_uvector values(total_rep_count, stream, mr); + extract_repeated_in_nested_varint_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); + child_values = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, total_rep_count, values.release(), rmm::device_buffer{}, 0); + } else if (elem_type_id == cudf::type_id::INT64) { + rmm::device_uvector values(total_rep_count, stream, mr); + extract_repeated_in_nested_varint_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); + child_values = std::make_unique( + 
cudf::data_type{cudf::type_id::INT64}, total_rep_count, values.release(), rmm::device_buffer{}, 0); + } else if (elem_type_id == cudf::type_id::BOOL8) { + rmm::device_uvector values(total_rep_count, stream, mr); + extract_repeated_in_nested_varint_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); + child_values = std::make_unique( + cudf::data_type{cudf::type_id::BOOL8}, total_rep_count, values.release(), rmm::device_buffer{}, 0); + } else if (elem_type_id == cudf::type_id::FLOAT32) { + rmm::device_uvector values(total_rep_count, stream, mr); + extract_repeated_in_nested_fixed_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); + child_values = std::make_unique( + cudf::data_type{cudf::type_id::FLOAT32}, total_rep_count, values.release(), rmm::device_buffer{}, 0); + } else if (elem_type_id == cudf::type_id::FLOAT64) { + rmm::device_uvector values(total_rep_count, stream, mr); + extract_repeated_in_nested_fixed_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); + child_values = std::make_unique( + cudf::data_type{cudf::type_id::FLOAT64}, total_rep_count, values.release(), rmm::device_buffer{}, 0); + } else if (elem_type_id == cudf::type_id::STRING) { + rmm::device_uvector d_str_lengths(total_rep_count, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_occs.begin(), d_rep_occs.end(), + d_str_lengths.begin(), + [] __device__(repeated_occurrence const& occ) { return occ.length; }); + + int32_t total_chars = thrust::reduce(rmm::exec_policy(stream), + d_str_lengths.begin(), d_str_lengths.end(), 0); + 
rmm::device_uvector str_offs(total_rep_count + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), + d_str_lengths.begin(), d_str_lengths.end(), + str_offs.begin(), 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data() + total_rep_count, &total_chars, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + extract_repeated_in_nested_string_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_rep_occs.data(), total_rep_count, str_offs.data(), chars.data(), d_error.data()); + } + + auto str_offs_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, total_rep_count + 1, str_offs.release(), rmm::device_buffer{}, 0); + child_values = cudf::make_strings_column(total_rep_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); + } else if (elem_type_id == cudf::type_id::LIST) { + rmm::device_uvector d_len(total_rep_count, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_occs.begin(), d_rep_occs.end(), + d_len.begin(), + [] __device__(repeated_occurrence const& occ) { return occ.length; }); + + int32_t total_bytes = thrust::reduce(rmm::exec_policy(stream), + d_len.begin(), d_len.end(), 0); + rmm::device_uvector byte_offs(total_rep_count + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), + d_len.begin(), d_len.end(), + byte_offs.begin(), 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(byte_offs.data() + total_rep_count, &total_bytes, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector bytes(total_bytes, stream, mr); + if (total_bytes > 0) { + extract_repeated_in_nested_string_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_rep_occs.data(), total_rep_count, byte_offs.data(), bytes.data(), d_error.data()); + } + + auto 
offs_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, total_rep_count + 1, byte_offs.release(), rmm::device_buffer{}, 0); + auto bytes_child = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, total_bytes, + rmm::device_buffer(bytes.data(), total_bytes, stream, mr), rmm::device_buffer{}, 0); + child_values = cudf::make_lists_column( + total_rep_count, std::move(offs_col), std::move(bytes_child), 0, rmm::device_buffer{}, stream, mr); + } else if (elem_type_id == cudf::type_id::STRUCT) { + // Repeated message field (ArrayType(StructType)) inside nested message. + // Build virtual parent info for each occurrence so we can recursively decode children. + auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx); + if (gc_indices.empty()) { + child_values = cudf::make_structs_column( + total_rep_count, std::vector>{}, + 0, rmm::device_buffer{}, stream, mr); + } else { + rmm::device_uvector d_virtual_row_offsets(total_rep_count, stream, mr); + rmm::device_uvector d_virtual_parent_locs(total_rep_count, stream, mr); + auto const rep_blk = (total_rep_count + 255) / 256; + compute_virtual_parents_for_nested_repeated_kernel<<>>( + d_rep_occs.data(), list_offsets, d_parent_locs.data(), + d_virtual_row_offsets.data(), d_virtual_parent_locs.data(), total_rep_count); + + child_values = build_nested_struct_column( + message_data, d_virtual_row_offsets.data(), base_offset, d_virtual_parent_locs, + gc_indices, schema, num_fields, schema_output_types, default_ints, default_floats, + default_bools, default_strings, enum_valid_values, enum_names, + d_row_has_invalid_enum, d_error, total_rep_count, stream, mr, depth + 1); + } + } else { + child_values = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); + } + + auto list_offs_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), rmm::device_buffer{}, 0); + struct_children.push_back(cudf::make_lists_column( + num_rows, 
std::move(list_offs_col), std::move(child_values), 0, rmm::device_buffer{}, stream, mr)); + } + continue; + } + + switch (dt.id()) { + case cudf::type_id::BOOL8: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_val = has_def ? (default_bools[child_schema_idx] ? 1 : 0) : 0; + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_val); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::INT32: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; + if (enc == spark_rapids_jni::ENC_ZIGZAG) { + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + } else if (enc == spark_rapids_jni::ENC_FIXED) { + extract_nested_fixed_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, static_cast(def_int)); + } else { + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::UINT32: { + 
rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; + if (enc == spark_rapids_jni::ENC_FIXED) { + extract_nested_fixed_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, static_cast(def_int)); + } else { + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::INT64: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; + if (enc == spark_rapids_jni::ENC_ZIGZAG) { + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + } else if (enc == spark_rapids_jni::ENC_FIXED) { + extract_nested_fixed_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + } else { + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::UINT64: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; + if (enc == spark_rapids_jni::ENC_FIXED) { + extract_nested_fixed_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, static_cast(def_int)); + } else { + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::FLOAT32: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + float def_float = has_def ? static_cast(default_floats[child_schema_idx]) : 0.0f; + extract_nested_fixed_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_float); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::FLOAT64: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + double def_double = has_def ? 
default_floats[child_schema_idx] : 0.0; + extract_nested_fixed_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_double); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + struct_children.push_back(std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count)); + break; + } + case cudf::type_id::STRING: { + if (enc == spark_rapids_jni::ENC_ENUM_STRING) { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; + extract_nested_varint_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), + num_rows, d_error.data(), has_def, def_int); + + if (child_schema_idx < static_cast(enum_valid_values.size()) && + child_schema_idx < static_cast(enum_names.size())) { + auto const& valid_enums = enum_valid_values[child_schema_idx]; + auto const& enum_name_bytes = enum_names[child_schema_idx]; + if (!valid_enums.empty() && valid_enums.size() == enum_name_bytes.size()) { + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), valid_enums.data(), + valid_enums.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + validate_enum_values_kernel<<>>( + out.data(), valid.data(), d_row_has_invalid_enum.data(), + d_valid_enums.data(), static_cast(valid_enums.size()), num_rows); + + std::vector h_name_offsets(valid_enums.size() + 1, 0); + int32_t total_name_chars = 0; + for (size_t k = 0; k < enum_name_bytes.size(); ++k) { + total_name_chars += static_cast(enum_name_bytes[k].size()); + h_name_offsets[k + 1] = total_name_chars; + } + std::vector h_name_chars(total_name_chars); + int32_t cursor = 0; + for (auto const& name : 
enum_name_bytes) { + if (!name.empty()) { + std::copy(name.begin(), name.end(), h_name_chars.begin() + cursor); + cursor += static_cast(name.size()); + } + } + + rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), h_name_offsets.data(), + h_name_offsets.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + rmm::device_uvector d_name_chars(total_name_chars, stream, mr); + if (total_name_chars > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), h_name_chars.data(), + total_name_chars * sizeof(uint8_t), cudaMemcpyHostToDevice, stream.value())); + } + + rmm::device_uvector lengths(num_rows, stream, mr); + compute_enum_string_lengths_kernel<<>>( + out.data(), valid.data(), d_valid_enums.data(), d_name_offsets.data(), + static_cast(valid_enums.size()), lengths.data(), num_rows); + + rmm::device_uvector output_offsets(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), + output_offsets.begin(), 0); + + int32_t total_chars = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, output_offsets.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + total_chars += last_len; + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_chars, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + copy_enum_string_chars_kernel<<>>( + out.data(), valid.data(), d_valid_enums.data(), d_name_offsets.data(), + d_name_chars.data(), static_cast(valid_enums.size()), + output_offsets.data(), chars.data(), num_rows); + } + + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + auto offsets_col = 
std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, output_offsets.release(), + rmm::device_buffer{}, 0); + struct_children.push_back(cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask))); + } else { + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value())); + struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); + } + } else { + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value())); + struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); + } + } else { + bool has_def_str = has_def; + auto const& def_str = default_strings[child_schema_idx]; + int32_t def_len = has_def_str ? static_cast(def_str.size()) : 0; + + rmm::device_uvector d_default_str(def_len, stream, mr); + if (has_def_str && def_len > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), def_str.data(), def_len, + cudaMemcpyHostToDevice, stream.value())); + } + + rmm::device_uvector lengths(num_rows, stream, mr); + extract_nested_lengths_kernel<<>>( + d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields, + lengths.data(), num_rows, has_def_str, def_len); + + rmm::device_uvector output_offsets(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), + output_offsets.begin(), 0); + + int32_t total_chars = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, output_offsets.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + total_chars += last_len; + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_chars, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + 
copy_nested_varlen_data_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, output_offsets.data(), + chars.data(), num_rows, has_def_str, d_default_str.data(), def_len); + } + + rmm::device_uvector valid(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.begin(), + [plocs = d_parent_locs.data(), + flocs = d_child_locations.data(), + ci, num_child_fields, has_def_str] __device__(auto row) { + return (plocs[row].offset >= 0 && + flocs[row * num_child_fields + ci].offset >= 0) || has_def_str; + }); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, output_offsets.release(), + rmm::device_buffer{}, 0); + struct_children.push_back(cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask))); + } + break; + } + case cudf::type_id::LIST: { + // bytes (BinaryType) represented as LIST + bool has_def_bytes = has_def; + auto const& def_bytes = default_strings[child_schema_idx]; + int32_t def_len = has_def_bytes ? 
static_cast(def_bytes.size()) : 0; + + rmm::device_uvector d_default_bytes(def_len, stream, mr); + if (has_def_bytes && def_len > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_bytes.data(), def_bytes.data(), def_len, + cudaMemcpyHostToDevice, stream.value())); + } + + rmm::device_uvector lengths(num_rows, stream, mr); + extract_nested_lengths_kernel<<>>( + d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields, + lengths.data(), num_rows, has_def_bytes, def_len); + + rmm::device_uvector output_offsets(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), + output_offsets.begin(), 0); + + int32_t total_bytes = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, output_offsets.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + total_bytes += last_len; + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_bytes, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + rmm::device_uvector bytes_data(total_bytes, stream, mr); + if (total_bytes > 0) { + copy_nested_varlen_data_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), + d_child_locations.data(), ci, num_child_fields, output_offsets.data(), + bytes_data.data(), num_rows, has_def_bytes, d_default_bytes.data(), def_len); + } + + rmm::device_uvector valid(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.begin(), + [plocs = d_parent_locs.data(), + flocs = d_child_locations.data(), + ci, num_child_fields, has_def_bytes] __device__(auto row) { + return (plocs[row].offset >= 0 && + flocs[row * num_child_fields + ci].offset >= 0) || has_def_bytes; + }); + auto [mask, null_count] = 
make_null_mask_from_valid(valid, stream, mr); + + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, output_offsets.release(), + rmm::device_buffer{}, 0); + auto bytes_child = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, total_bytes, + rmm::device_buffer(bytes_data.data(), total_bytes, stream, mr), rmm::device_buffer{}, 0); + struct_children.push_back(cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(bytes_child), null_count, std::move(mask), stream, mr)); + break; + } + case cudf::type_id::STRUCT: { + auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx); + if (gc_indices.empty()) { + struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); + break; + } + rmm::device_uvector d_gc_parent(num_rows, stream, mr); + compute_grandchild_parent_locations_kernel<<>>( + d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields, + d_gc_parent.data(), num_rows); + struct_children.push_back(build_nested_struct_column( + message_data, list_offsets, base_offset, d_gc_parent, gc_indices, + schema, num_fields, schema_output_types, default_ints, default_floats, default_bools, + default_strings, enum_valid_values, enum_names, d_row_has_invalid_enum, d_error, + num_rows, stream, mr, depth + 1)); + break; + } + default: + struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); + break; + } + } + + rmm::device_uvector struct_valid(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + struct_valid.begin(), + [plocs = d_parent_locs.data()] __device__(auto row) { + return plocs[row].offset >= 0; + }); + auto [struct_mask, struct_null_count] = make_null_mask_from_valid(struct_valid, stream, mr); + return cudf::make_structs_column( + num_rows, std::move(struct_children), struct_null_count, std::move(struct_mask), stream, mr); +} + +} // anonymous 
namespace + +std::unique_ptr decode_protobuf_to_struct( + cudf::column_view const& binary_input, + std::vector const& schema, + std::vector const& schema_output_types, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + bool fail_on_errors) +{ + CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, + "binary_input must be a LIST column"); + cudf::lists_column_view const in_list(binary_input); + auto const child_type = in_list.child().type().id(); + CUDF_EXPECTS(child_type == cudf::type_id::INT8 || child_type == cudf::type_id::UINT8, + "binary_input must be a LIST column"); + + auto const stream = cudf::get_default_stream(); + auto mr = cudf::get_current_device_resource_ref(); + auto num_rows = binary_input.size(); + auto num_fields = static_cast(schema.size()); + + if (num_rows == 0 || num_fields == 0) { + // Build empty struct based on top-level fields with proper nested structure + std::vector> empty_children; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == -1) { + auto field_type = schema_output_types[i]; + if (schema[i].is_repeated && field_type.id() == cudf::type_id::STRUCT) { + // Repeated message field - build empty LIST with proper struct element + rmm::device_uvector offsets(1, stream, mr); + int32_t zero = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(offsets.data(), &zero, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, 1, offsets.release(), rmm::device_buffer{}, 0); + auto empty_struct = make_empty_struct_column_with_schema( + schema, schema_output_types, i, num_fields, stream, mr); + empty_children.push_back(cudf::make_lists_column(0, std::move(offsets_col), std::move(empty_struct), + 0, rmm::device_buffer{}, stream, mr)); + } else if (field_type.id() == 
cudf::type_id::STRUCT && !schema[i].is_repeated) { + // Non-repeated nested message field + empty_children.push_back(make_empty_struct_column_with_schema( + schema, schema_output_types, i, num_fields, stream, mr)); + } else { + empty_children.push_back(make_empty_column_safe(field_type, stream, mr)); + } + } + } + return cudf::make_structs_column(0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); + } + + // Copy schema to device + std::vector h_device_schema(num_fields); + for (int i = 0; i < num_fields; i++) { + h_device_schema[i] = { + schema[i].field_number, + schema[i].parent_idx, + schema[i].depth, + schema[i].wire_type, + static_cast(schema[i].output_type), + schema[i].encoding, + schema[i].is_repeated, + schema[i].is_required, + schema[i].has_default_value + }; + } + + rmm::device_uvector d_schema(num_fields, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_schema.data(), h_device_schema.data(), + num_fields * sizeof(device_nested_field_descriptor), + cudaMemcpyHostToDevice, stream.value())); + + auto d_in = cudf::column_device_view::create(binary_input, stream); + + // Identify repeated and nested fields at depth 0 std::vector repeated_field_indices; std::vector nested_field_indices; std::vector scalar_field_indices; @@ -3048,6 +3885,24 @@ std::unique_ptr decode_protobuf_to_struct( dt, num_rows, out.release(), std::move(mask), null_count); break; } + case cudf::type_id::UINT32: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; + if (enc == spark_rapids_jni::ENC_FIXED) { + extract_fixed_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, static_cast(def_int)); + } else { + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + column_map[schema_idx] = std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count); + break; + } case cudf::type_id::INT64: { rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid(num_rows, stream, mr); @@ -3070,6 +3925,24 @@ std::unique_ptr decode_protobuf_to_struct( dt, num_rows, out.release(), std::move(mask), null_count); break; } + case cudf::type_id::UINT64: { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; + if (enc == spark_rapids_jni::ENC_FIXED) { + extract_fixed_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, static_cast(def_int)); + } else { + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + column_map[schema_idx] = std::make_unique( + dt, num_rows, out.release(), std::move(mask), null_count); + break; + } case cudf::type_id::FLOAT32: { rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid(num_rows, stream, mr); @@ -3095,63 +3968,219 @@ std::unique_ptr decode_protobuf_to_struct( break; } case cudf::type_id::STRING: { - // Extract top-level STRING scalar field - bool has_def_str = has_def && !default_strings[schema_idx].empty(); - auto const& def_str = default_strings[schema_idx]; - int32_t def_len = has_def_str ? static_cast(def_str.size()) : 0; + if (enc == spark_rapids_jni::ENC_ENUM_STRING) { + // ENUM-as-string path: + // 1. Decode enum numeric value as INT32 varint. + // 2. Validate against enum_valid_values. + // 3. Convert INT32 -> UTF-8 enum name bytes. + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid(num_rows, stream, mr); + int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; + extract_varint_from_locations_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); - rmm::device_uvector d_default_str(def_len, stream, mr); - if (has_def_str && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), def_str.data(), def_len, + if (schema_idx < static_cast(enum_valid_values.size()) && + schema_idx < static_cast(enum_names.size())) { + auto const& valid_enums = enum_valid_values[schema_idx]; + auto const& enum_name_bytes = enum_names[schema_idx]; + if (!valid_enums.empty() && valid_enums.size() == enum_name_bytes.size()) { + // Validate enum numeric values first. + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), valid_enums.data(), + valid_enums.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + validate_enum_values_kernel<<>>( + out.data(), valid.data(), d_row_has_invalid_enum.data(), + d_valid_enums.data(), static_cast(valid_enums.size()), num_rows); + + // Build flattened enum-name chars and offsets on host, then copy to device. 
+ std::vector h_name_offsets(valid_enums.size() + 1, 0); + int32_t total_name_chars = 0; + for (size_t k = 0; k < enum_name_bytes.size(); ++k) { + total_name_chars += static_cast(enum_name_bytes[k].size()); + h_name_offsets[k + 1] = total_name_chars; + } + std::vector h_name_chars(total_name_chars); + int32_t cursor = 0; + for (auto const& name : enum_name_bytes) { + if (!name.empty()) { + std::copy(name.begin(), name.end(), h_name_chars.begin() + cursor); + cursor += static_cast(name.size()); + } + } + + rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), h_name_offsets.data(), + h_name_offsets.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + rmm::device_uvector d_name_chars(total_name_chars, stream, mr); + if (total_name_chars > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), h_name_chars.data(), + total_name_chars * sizeof(uint8_t), cudaMemcpyHostToDevice, stream.value())); + } + + // Compute output UTF-8 lengths + rmm::device_uvector lengths(num_rows, stream, mr); + compute_enum_string_lengths_kernel<<>>( + out.data(), valid.data(), d_valid_enums.data(), d_name_offsets.data(), + static_cast(valid_enums.size()), lengths.data(), num_rows); + + // Prefix sum for string offsets + rmm::device_uvector output_offsets(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), + output_offsets.begin(), 0); + + int32_t total_chars = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, output_offsets.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + total_chars += last_len; + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_chars, + sizeof(int32_t), cudaMemcpyHostToDevice, 
stream.value())); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + copy_enum_string_chars_kernel<<>>( + out.data(), valid.data(), d_valid_enums.data(), d_name_offsets.data(), + d_name_chars.data(), static_cast(valid_enums.size()), + output_offsets.data(), chars.data(), num_rows); + } + + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, output_offsets.release(), + rmm::device_buffer{}, 0); + column_map[schema_idx] = cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); + } else { + // Missing enum metadata for enum-as-string field; mark as decode error. + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value())); + column_map[schema_idx] = make_null_column(dt, num_rows, stream, mr); + } + } else { + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value())); + column_map[schema_idx] = make_null_column(dt, num_rows, stream, mr); + } + } else { + // Regular protobuf STRING (length-delimited) + bool has_def_str = has_def; + auto const& def_str = default_strings[schema_idx]; + int32_t def_len = has_def_str ? 
static_cast(def_str.size()) : 0; + + rmm::device_uvector d_default_str(def_len, stream, mr); + if (has_def_str && def_len > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), def_str.data(), def_len, + cudaMemcpyHostToDevice, stream.value())); + } + + // Extract string lengths + rmm::device_uvector lengths(num_rows, stream, mr); + extract_scalar_string_lengths_kernel<<>>( + d_locations.data(), i, num_scalar, lengths.data(), num_rows, has_def_str, def_len); + + // Compute offsets via prefix sum + rmm::device_uvector output_offsets(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), + output_offsets.begin(), 0); + + int32_t total_chars = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, output_offsets.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + total_chars += last_len; + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_chars, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + // Copy string data + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + copy_scalar_string_data_kernel<<>>( + message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, + output_offsets.data(), chars.data(), num_rows, has_def_str, + d_default_str.data(), def_len); + } + + // Build validity mask + rmm::device_uvector valid(num_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.begin(), + [locs = d_locations.data(), i, num_scalar, has_def_str] __device__(auto row) { + return locs[row * num_scalar + i].offset >= 0 || has_def_str; + }); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + + auto offsets_col = 
std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, output_offsets.release(), + rmm::device_buffer{}, 0); + column_map[schema_idx] = cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); + } + break; + } + case cudf::type_id::LIST: { + // bytes (BinaryType) represented as LIST + bool has_def_bytes = has_def; + auto const& def_bytes = default_strings[schema_idx]; + int32_t def_len = has_def_bytes ? static_cast(def_bytes.size()) : 0; + + rmm::device_uvector d_default_bytes(def_len, stream, mr); + if (has_def_bytes && def_len > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_bytes.data(), def_bytes.data(), def_len, cudaMemcpyHostToDevice, stream.value())); } - // Extract string lengths rmm::device_uvector lengths(num_rows, stream, mr); extract_scalar_string_lengths_kernel<<>>( - d_locations.data(), i, num_scalar, lengths.data(), num_rows, has_def_str, def_len); + d_locations.data(), i, num_scalar, lengths.data(), num_rows, has_def_bytes, def_len); - // Compute offsets via prefix sum rmm::device_uvector output_offsets(num_rows + 1, stream, mr); thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); - int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, output_offsets.data() + num_rows - 1, + int32_t total_bytes = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, output_offsets.data() + num_rows - 1, sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); int32_t last_len = 0; CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - total_chars += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_chars, + total_bytes += last_len; + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_bytes, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - // Copy string data - 
rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { + rmm::device_uvector bytes_data(total_bytes, stream, mr); + if (total_bytes > 0) { copy_scalar_string_data_kernel<<>>( message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - output_offsets.data(), chars.data(), num_rows, has_def_str, - d_default_str.data(), def_len); + output_offsets.data(), bytes_data.data(), num_rows, has_def_bytes, + d_default_bytes.data(), def_len); } - // Build validity mask rmm::device_uvector valid(num_rows, stream, mr); thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), valid.begin(), - [locs = d_locations.data(), i, num_scalar, has_def_str] __device__(auto row) { - return locs[row * num_scalar + i].offset >= 0 || has_def_str; + [locs = d_locations.data(), i, num_scalar, has_def_bytes] __device__(auto row) { + return locs[row * num_scalar + i].offset >= 0 || has_def_bytes; }); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, output_offsets.release(), - rmm::device_buffer{}, 0); - column_map[schema_idx] = cudf::make_strings_column( - num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, output_offsets.release(), + rmm::device_buffer{}, 0); + auto bytes_child = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, total_bytes, + rmm::device_buffer(bytes_data.data(), total_bytes, stream, mr), rmm::device_buffer{}, 0); + column_map[schema_idx] = cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(bytes_child), null_count, std::move(mask), stream, mr); break; } default: @@ -3242,11 +4271,138 @@ std::unique_ptr decode_protobuf_to_struct( binary_input, h_device_schema[schema_idx], field_info, d_occurrences, total_count, 
num_rows, stream, mr); break; - case cudf::type_id::STRING: - column_map[schema_idx] = build_repeated_string_column( - binary_input, h_device_schema[schema_idx], field_info, d_occurrences, - total_count, num_rows, false, stream, mr); + case cudf::type_id::STRING: { + auto enc = schema[schema_idx].encoding; + if (enc == spark_rapids_jni::ENC_ENUM_STRING && + schema_idx < static_cast(enum_valid_values.size()) && + schema_idx < static_cast(enum_names.size()) && + !enum_valid_values[schema_idx].empty() && + enum_valid_values[schema_idx].size() == enum_names[schema_idx].size()) { + // Repeated enum-as-string: extract varints, then convert to strings. + auto const& valid_enums = enum_valid_values[schema_idx]; + auto const& name_bytes = enum_names[schema_idx]; + + cudf::lists_column_view const in_lv(binary_input); + auto const* msg_data = reinterpret_cast(in_lv.child().data()); + auto const* loffs = in_lv.offsets().data(); + + cudf::size_type boff = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&boff, loffs, sizeof(cudf::size_type), + cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + // 1. Extract enum integer values from occurrences + rmm::device_uvector enum_ints(total_count, stream, mr); + auto const rep_blocks = static_cast((total_count + 255) / 256); + extract_repeated_varint_kernel<<>>( + msg_data, loffs, boff, d_occurrences.data(), total_count, + enum_ints.data(), d_error.data()); + + // 2. 
Build device-side enum lookup tables + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), valid_enums.data(), + valid_enums.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + std::vector h_name_offsets(valid_enums.size() + 1, 0); + int32_t total_name_chars = 0; + for (size_t k = 0; k < name_bytes.size(); ++k) { + total_name_chars += static_cast(name_bytes[k].size()); + h_name_offsets[k + 1] = total_name_chars; + } + std::vector h_name_chars(total_name_chars); + int32_t cursor = 0; + for (auto const& nm : name_bytes) { + if (!nm.empty()) { + std::copy(nm.begin(), nm.end(), h_name_chars.begin() + cursor); + cursor += static_cast(nm.size()); + } + } + rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), h_name_offsets.data(), + h_name_offsets.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + rmm::device_uvector d_name_chars(total_name_chars, stream, mr); + if (total_name_chars > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), h_name_chars.data(), + total_name_chars * sizeof(uint8_t), cudaMemcpyHostToDevice, stream.value())); + } + + // 3. Validate enum values (sets row_has_invalid_enum for PERMISSIVE mode). + // We also need per-element validity for string building. + rmm::device_uvector elem_valid(total_count, stream, mr); + thrust::fill(rmm::exec_policy(stream), elem_valid.begin(), elem_valid.end(), true); + // validate_enum_values_kernel works on per-row basis; here we need per-element. + // Binary-search each element inline via the lengths kernel below. + + // 4. Compute per-element string lengths + rmm::device_uvector elem_lengths(total_count, stream, mr); + compute_enum_string_lengths_kernel<<>>( + enum_ints.data(), elem_valid.data(), d_valid_enums.data(), d_name_offsets.data(), + static_cast(valid_enums.size()), elem_lengths.data(), total_count); + + // 5. 
Build string offsets + rmm::device_uvector str_offsets(total_count + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), elem_lengths.begin(), elem_lengths.end(), + str_offsets.begin(), 0); + + int32_t total_chars = 0; + if (total_count > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, str_offsets.data() + total_count - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, elem_lengths.data() + total_count - 1, + sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + total_chars += last_len; + } + CUDF_CUDA_TRY(cudaMemcpyAsync(str_offsets.data() + total_count, &total_chars, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + // 6. Copy string chars + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + copy_enum_string_chars_kernel<<>>( + enum_ints.data(), elem_valid.data(), d_valid_enums.data(), d_name_offsets.data(), + d_name_chars.data(), static_cast(valid_enums.size()), + str_offsets.data(), chars.data(), total_count); + } + + // 7. 
Assemble LIST column + auto str_offs_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, total_count + 1, str_offsets.release(), + rmm::device_buffer{}, 0); + auto child_col = cudf::make_strings_column( + total_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); + + // Build list offsets from per-row counts + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), + d_field_counts.begin(), d_field_counts.end(), + list_offs.begin(), 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_count, + sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + auto list_offs_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), + rmm::device_buffer{}, 0); + + auto input_null_count = binary_input.null_count(); + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + column_map[schema_idx] = cudf::make_lists_column( + num_rows, std::move(list_offs_col), std::move(child_col), + input_null_count, std::move(null_mask), stream, mr); + } else { + column_map[schema_idx] = cudf::make_lists_column( + num_rows, std::move(list_offs_col), std::move(child_col), + 0, rmm::device_buffer{}, stream, mr); + } + } else { + column_map[schema_idx] = build_repeated_string_column( + binary_input, h_device_schema[schema_idx], field_info, d_occurrences, + total_count, num_rows, false, stream, mr); + } break; + } case cudf::type_id::LIST: // bytes as LIST column_map[schema_idx] = build_repeated_string_column( binary_input, h_device_schema[schema_idx], field_info, d_occurrences, @@ -3263,6 +4419,8 @@ std::unique_ptr decode_protobuf_to_struct( binary_input, h_device_schema[schema_idx], field_info, d_occurrences, total_count, num_rows, h_device_schema, child_field_indices, schema_output_types, default_ints, default_floats, default_bools, + default_strings, schema, enum_valid_values, enum_names, + 
d_row_has_invalid_enum, d_error, stream, mr); } break; @@ -3334,21 +4492,6 @@ std::unique_ptr decode_protobuf_to_struct( continue; } - int num_child_fields = static_cast(child_field_indices.size()); - - // Build field descriptors for child fields - std::vector h_child_field_descs(num_child_fields); - for (int i = 0; i < num_child_fields; i++) { - int child_idx = child_field_indices[i]; - h_child_field_descs[i].field_number = schema[child_idx].field_number; - h_child_field_descs[i].expected_wire_type = schema[child_idx].wire_type; - } - - rmm::device_uvector d_child_field_descs(num_child_fields, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_field_descs.data(), h_child_field_descs.data(), - num_child_fields * sizeof(field_descriptor), - cudaMemcpyHostToDevice, stream.value())); - // Prepare parent locations for this nested field rmm::device_uvector d_parent_locs(num_rows, stream, mr); std::vector h_parent_locs(num_rows); @@ -3358,387 +4501,12 @@ std::unique_ptr decode_protobuf_to_struct( CUDF_CUDA_TRY(cudaMemcpyAsync(d_parent_locs.data(), h_parent_locs.data(), num_rows * sizeof(field_location), cudaMemcpyHostToDevice, stream.value())); - - // Scan for child fields within nested messages - rmm::device_uvector d_child_locations( - static_cast(num_rows) * num_child_fields, stream, mr); - - scan_nested_message_fields_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, - d_child_field_descs.data(), num_child_fields, d_child_locations.data(), d_error.data()); - - // Extract child field values - std::vector> struct_children; - for (int ci = 0; ci < num_child_fields; ci++) { - int child_schema_idx = child_field_indices[ci]; - auto const dt = schema_output_types[child_schema_idx]; - auto const enc = schema[child_schema_idx].encoding; - bool has_def = schema[child_schema_idx].has_default_value; - bool is_repeated = schema[child_schema_idx].is_repeated; - - // Check if this is a repeated field (ArrayType) - if (is_repeated) { - // 
Handle repeated field inside nested message - auto elem_type_id = schema[child_schema_idx].output_type; - - // Copy child locations to host - std::vector h_rep_parent_locs(num_rows); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_rep_parent_locs.data(), d_parent_locs.data(), - num_rows * sizeof(field_location), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - - // Count repeated field occurrences for each row - rmm::device_uvector d_rep_info(num_rows, stream, mr); - - std::vector rep_indices = {0}; - rmm::device_uvector d_rep_indices(1, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_indices.data(), rep_indices.data(), - sizeof(int), cudaMemcpyHostToDevice, stream.value())); - - device_nested_field_descriptor rep_desc; - rep_desc.field_number = schema[child_schema_idx].field_number; - rep_desc.wire_type = schema[child_schema_idx].wire_type; - rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); - rep_desc.is_repeated = true; - - std::vector h_rep_schema = {rep_desc}; - rmm::device_uvector d_rep_schema(1, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_schema.data(), h_rep_schema.data(), - sizeof(device_nested_field_descriptor), cudaMemcpyHostToDevice, stream.value())); - - count_repeated_in_nested_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, - d_rep_schema.data(), 1, d_rep_info.data(), 1, d_rep_indices.data(), d_error.data()); - - // Compute total_rep_count on GPU using thrust::reduce (performance fix!) 
- // Extract counts from repeated_field_info on device - rmm::device_uvector d_rep_counts(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - d_rep_info.begin(), d_rep_info.end(), - d_rep_counts.begin(), - [] __device__(repeated_field_info const& info) { return info.count; }); - - int total_rep_count = thrust::reduce(rmm::exec_policy(stream), - d_rep_counts.begin(), d_rep_counts.end(), 0); - - if (total_rep_count == 0) { - rmm::device_uvector list_offsets_vec(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), list_offsets_vec.begin(), list_offsets_vec.end(), 0); - auto list_offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offsets_vec.release(), rmm::device_buffer{}, 0); - auto child_col = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); - struct_children.push_back(cudf::make_lists_column( - num_rows, std::move(list_offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr)); - } else { - rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); - scan_repeated_in_nested_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, - d_rep_schema.data(), 1, d_rep_info.data(), 1, d_rep_indices.data(), - d_rep_occs.data(), d_error.data()); - - // Compute list offsets on GPU using exclusive_scan (performance fix!) 
- rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - d_rep_counts.begin(), d_rep_counts.end(), - list_offs.begin(), 0); - // Set last element - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_rep_count, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - - std::unique_ptr child_values; - if (elem_type_id == cudf::type_id::INT32) { - rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_varint_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); - child_values = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, total_rep_count, values.release(), rmm::device_buffer{}, 0); - } else if (elem_type_id == cudf::type_id::INT64) { - rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_varint_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); - child_values = std::make_unique( - cudf::data_type{cudf::type_id::INT64}, total_rep_count, values.release(), rmm::device_buffer{}, 0); - } else if (elem_type_id == cudf::type_id::STRING) { - // Compute string offsets on GPU using thrust (performance fix!) 
- // Extract lengths from occurrences on device - rmm::device_uvector d_str_lengths(total_rep_count, stream, mr); - thrust::transform(rmm::exec_policy(stream), - d_rep_occs.begin(), d_rep_occs.end(), - d_str_lengths.begin(), - [] __device__(repeated_occurrence const& occ) { return occ.length; }); - - // Compute total chars and offsets - int32_t total_chars = thrust::reduce(rmm::exec_policy(stream), - d_str_lengths.begin(), d_str_lengths.end(), 0); - - rmm::device_uvector str_offs(total_rep_count + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - d_str_lengths.begin(), d_str_lengths.end(), - str_offs.begin(), 0); - // Set last element - CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data() + total_rep_count, &total_chars, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - - rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - extract_repeated_in_nested_string_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_rep_occs.data(), total_rep_count, str_offs.data(), chars.data(), d_error.data()); - } - - auto str_offs_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, total_rep_count + 1, str_offs.release(), rmm::device_buffer{}, 0); - child_values = cudf::make_strings_column(total_rep_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); - } else { - child_values = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); - } - - auto list_offs_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), rmm::device_buffer{}, 0); - struct_children.push_back(cudf::make_lists_column( - num_rows, std::move(list_offs_col), std::move(child_values), 0, rmm::device_buffer{}, stream, mr)); - } - continue; // Skip the switch statement below - } - - switch (dt.id()) { - case cudf::type_id::BOOL8: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector 
valid(num_rows, stream, mr); - int64_t def_val = has_def ? (default_bools[child_schema_idx] ? 1 : 0) : 0; - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_val); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::INT32: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); - int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; - if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); - } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_nested_fixed_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, static_cast(def_int)); - } else { - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::INT64: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); - int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; - if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); - } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_nested_fixed_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); - } else { - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::FLOAT32: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); - float def_float = has_def ? static_cast(default_floats[child_schema_idx]) : 0.0f; - extract_nested_fixed_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_float); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::FLOAT64: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); - double def_double = has_def ? 
default_floats[child_schema_idx] : 0.0; - extract_nested_fixed_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_double); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count)); - break; - } - case cudf::type_id::STRING: { - bool has_def_str = has_def && !default_strings[child_schema_idx].empty(); - auto const& def_str = default_strings[child_schema_idx]; - int32_t def_len = has_def_str ? static_cast(def_str.size()) : 0; - - rmm::device_uvector d_default_str(def_len, stream, mr); - if (has_def_str && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), def_str.data(), def_len, - cudaMemcpyHostToDevice, stream.value())); - } - - rmm::device_uvector lengths(num_rows, stream, mr); - extract_nested_lengths_kernel<<>>( - d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields, - lengths.data(), num_rows, has_def_str, def_len); - - rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), - output_offsets.begin(), 0); - - int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, output_offsets.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - total_chars += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_chars, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - - rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - copy_nested_varlen_data_kernel<<>>( - message_data, list_offsets, base_offset, 
d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, output_offsets.data(), - chars.data(), num_rows, has_def_str, d_default_str.data(), def_len); - } - - rmm::device_uvector valid(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - valid.begin(), - [plocs = d_parent_locs.data(), - flocs = d_child_locations.data(), - ci, num_child_fields, has_def_str] __device__(auto row) { - return (plocs[row].offset >= 0 && - flocs[row * num_child_fields + ci].offset >= 0) || has_def_str; - }); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, output_offsets.release(), - rmm::device_buffer{}, 0); - struct_children.push_back(cudf::make_strings_column( - num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask))); - break; - } - case cudf::type_id::STRUCT: { - // Recursively process nested struct (depth > 1) - auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx); - if (gc_indices.empty()) { - struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); - break; - } - int num_gc = static_cast(gc_indices.size()); - - // Get child struct locations for grandchild scanning using GPU kernel - // IMPORTANT: Need to compute ABSOLUTE offsets (relative to row start) - // d_child_locations contains offsets relative to parent message (Middle) - // We need: child_offset_in_row = parent_offset_in_row + child_offset_in_parent - // This is computed entirely on GPU to avoid D->H->D copy pattern (performance fix!) 
- rmm::device_uvector d_gc_parent(num_rows, stream, mr); - compute_grandchild_parent_locations_kernel<<>>( - d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields, - d_gc_parent.data(), num_rows); - - // Build grandchild field descriptors - std::vector h_gc_descs(num_gc); - for (int gi = 0; gi < num_gc; gi++) { - h_gc_descs[gi].field_number = schema[gc_indices[gi]].field_number; - h_gc_descs[gi].expected_wire_type = schema[gc_indices[gi]].wire_type; - } - rmm::device_uvector d_gc_descs(num_gc, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_gc_descs.data(), h_gc_descs.data(), - num_gc * sizeof(field_descriptor), cudaMemcpyHostToDevice, stream.value())); - - // Scan for grandchild fields - rmm::device_uvector d_gc_locs(num_rows * num_gc, stream, mr); - scan_nested_message_fields_kernel<<>>( - message_data, list_offsets, base_offset, d_gc_parent.data(), num_rows, - d_gc_descs.data(), num_gc, d_gc_locs.data(), d_error.data()); - - // Extract grandchild values (handle scalar types only) - std::vector> gc_cols; - for (int gi = 0; gi < num_gc; gi++) { - int gc_idx = gc_indices[gi]; - auto gc_dt = schema_output_types[gc_idx]; - bool gc_def = schema[gc_idx].has_default_value; - if (gc_dt.id() == cudf::type_id::INT32) { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector val(num_rows, stream, mr); - int64_t dv = gc_def ? 
default_ints[gc_idx] : 0; - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_gc_parent.data(), - d_gc_locs.data(), gi, num_gc, out.data(), val.data(), num_rows, d_error.data(), gc_def, dv); - auto [m, nc] = make_null_mask_from_valid(val, stream, mr); - gc_cols.push_back(std::make_unique(gc_dt, num_rows, out.release(), std::move(m), nc)); - } else { - gc_cols.push_back(make_null_column(gc_dt, num_rows, stream, mr)); - } - } - - // Build nested struct validity - rmm::device_uvector ns_valid(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), ns_valid.begin(), - [p = d_parent_locs.data(), c = d_child_locations.data(), ci, ncf = num_child_fields] __device__(auto r) { - return p[r].offset >= 0 && c[r * ncf + ci].offset >= 0; - }); - auto [ns_mask, ns_nc] = make_null_mask_from_valid(ns_valid, stream, mr); - struct_children.push_back(cudf::make_structs_column(num_rows, std::move(gc_cols), ns_nc, std::move(ns_mask), stream, mr)); - break; - } - default: - // For unsupported types, create null columns - struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); - break; - } - } - - // Build struct validity based on parent location - rmm::device_uvector struct_valid(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - struct_valid.begin(), - [plocs = d_parent_locs.data()] __device__(auto row) { - return plocs[row].offset >= 0; - }); - auto [struct_mask, struct_null_count] = make_null_mask_from_valid(struct_valid, stream, mr); - - column_map[parent_schema_idx] = cudf::make_structs_column( - num_rows, std::move(struct_children), struct_null_count, std::move(struct_mask), stream, mr); + column_map[parent_schema_idx] = build_nested_struct_column( + message_data, list_offsets, base_offset, d_parent_locs, + child_field_indices, schema, 
num_fields, schema_output_types, + default_ints, default_floats, default_bools, default_strings, + enum_valid_values, enum_names, d_row_has_invalid_enum, d_error, + num_rows, stream, mr, 0); } } diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index e8259d5065..3fc3e7dc97 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -29,6 +29,7 @@ namespace spark_rapids_jni { constexpr int ENC_DEFAULT = 0; constexpr int ENC_FIXED = 1; constexpr int ENC_ZIGZAG = 2; +constexpr int ENC_ENUM_STRING = 3; // Maximum nesting depth for nested messages constexpr int MAX_NESTING_DEPTH = 10; @@ -81,6 +82,8 @@ struct nested_field_descriptor { * @param default_bools Default values for bool fields * @param default_strings Default values for string/bytes fields * @param enum_valid_values Valid enum values for each field (empty if not enum) + * @param enum_names Enum names for enum-as-string fields (empty if not enum-as-string), + * ordered in parallel with enum_valid_values * @param fail_on_errors Whether to throw on malformed data * @return STRUCT column with nested structure */ @@ -93,6 +96,7 @@ std::unique_ptr decode_protobuf_to_struct( std::vector const& default_bools, std::vector> const& default_strings, std::vector> const& enum_valid_values, + std::vector>> const& enum_names, bool fail_on_errors); } // namespace spark_rapids_jni diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index 068ccdbe18..03ead2f4a1 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -53,6 +53,7 @@ public class Protobuf { public static final int ENC_DEFAULT = 0; public static final int ENC_FIXED = 1; public static final int ENC_ZIGZAG = 2; + public static final int ENC_ENUM_STRING = 3; // Wire type constants public static final int WT_VARINT = 0; @@ -73,7 +74,8 @@ public class Protobuf { * @param 
depthLevels Nesting depth for each field (0 for top-level). * @param wireTypes Expected wire type for each field (WT_VARINT, WT_64BIT, WT_LEN, WT_32BIT). * @param outputTypeIds cudf native type ids for output columns. - * @param encodings Encoding info for each field (0=default, 1=fixed, 2=zigzag). + * @param encodings Encoding info for each field (0=default, 1=fixed, 2=zigzag, + * 3=enum-as-string). * @param isRepeated Whether each field is a repeated field (array). * @param isRequired Whether each field is required (proto2). * @param hasDefaultValue Whether each field has a default value. @@ -82,6 +84,9 @@ public class Protobuf { * @param defaultBools Default values for bool fields. * @param defaultStrings Default values for string/bytes fields as UTF-8 bytes. * @param enumValidValues Valid enum values for each field (null if not an enum). + * @param enumNames Enum value names for enum-as-string fields (null if not enum-as-string). + * For each field, this is a byte[][] containing UTF-8 enum names ordered by + * the same sorted order as enumValidValues for that field. * @param failOnErrors if true, throw an exception on malformed protobuf messages. * @return a cudf STRUCT column with nested structure. 
*/ @@ -100,13 +105,14 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, boolean[] defaultBools, byte[][] defaultStrings, int[][] enumValidValues, + byte[][][] enumNames, boolean failOnErrors) { // Parameter validation if (fieldNumbers == null || parentIndices == null || depthLevels == null || wireTypes == null || outputTypeIds == null || encodings == null || isRepeated == null || isRequired == null || hasDefaultValue == null || defaultInts == null || defaultFloats == null || defaultBools == null || - defaultStrings == null || enumValidValues == null) { + defaultStrings == null || enumValidValues == null || enumNames == null) { throw new IllegalArgumentException("Arrays must be non-null"); } @@ -123,7 +129,8 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, defaultFloats.length != numFields || defaultBools.length != numFields || defaultStrings.length != numFields || - enumValidValues.length != numFields) { + enumValidValues.length != numFields || + enumNames.length != numFields) { throw new IllegalArgumentException("All arrays must have the same length"); } @@ -139,10 +146,11 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, // Validate encoding values for (int i = 0; i < encodings.length; i++) { int enc = encodings[i]; - if (enc < ENC_DEFAULT || enc > ENC_ZIGZAG) { + if (enc < ENC_DEFAULT || enc > ENC_ENUM_STRING) { throw new IllegalArgumentException( "Invalid encoding value at index " + i + ": " + enc + - " (expected " + ENC_DEFAULT + ", " + ENC_FIXED + ", or " + ENC_ZIGZAG + ")"); + " (expected " + ENC_DEFAULT + ", " + ENC_FIXED + ", " + ENC_ZIGZAG + + ", or " + ENC_ENUM_STRING + ")"); } } @@ -151,10 +159,36 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, wireTypes, outputTypeIds, encodings, isRepeated, isRequired, hasDefaultValue, defaultInts, defaultFloats, defaultBools, - defaultStrings, enumValidValues, failOnErrors); + defaultStrings, enumValidValues, enumNames, failOnErrors); 
return new ColumnVector(handle); } + /** + * Backward-compatible overload for callers that don't provide enum name mappings. + * This keeps existing JNI tests and call-sites source-compatible. + */ + public static ColumnVector decodeToStruct(ColumnView binaryInput, + int[] fieldNumbers, + int[] parentIndices, + int[] depthLevels, + int[] wireTypes, + int[] outputTypeIds, + int[] encodings, + boolean[] isRepeated, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + int[][] enumValidValues, + boolean failOnErrors) { + return decodeToStruct(binaryInput, fieldNumbers, parentIndices, depthLevels, wireTypes, + outputTypeIds, encodings, isRepeated, isRequired, hasDefaultValue, defaultInts, + defaultFloats, defaultBools, defaultStrings, enumValidValues, + new byte[fieldNumbers.length][][], failOnErrors); + } + private static native long decodeToStruct(long binaryInputView, int[] fieldNumbers, int[] parentIndices, @@ -170,5 +204,6 @@ private static native long decodeToStruct(long binaryInputView, boolean[] defaultBools, byte[][] defaultStrings, int[][] enumValidValues, + byte[][][] enumNames, boolean failOnErrors); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 20bbb0cd47..3d536445cf 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -297,7 +297,7 @@ void decodeMoreTypes() { ColumnVector expectedS64 = ColumnVector.fromBoxedLongs(-1234567890123L); ColumnVector expectedF32 = ColumnVector.fromBoxedInts(12345); ColumnVector expectedB = ColumnVector.fromLists( - new ListType(true, new BasicType(true, DType.INT8)), + new ListType(true, new BasicType(true, DType.UINT8)), Arrays.asList((byte) 1, (byte) 2, (byte) 3)); ColumnVector actualStruct = decodeAllFields( input.getColumn(0), @@ -1725,6 
+1725,58 @@ void testNestedMessage() { } } + @Test + void testDeepNestedMessageDepth3() { + // message Inner { int32 a = 1; string b = 2; bool c = 3; } + // message Middle { Inner inner = 1; int64 m = 2; } + // message Outer { Middle middle = 1; float score = 2; } + Byte[] innerMessage = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(7)), + box(tag(2, WT_LEN)), box(encodeVarint(3)), box("abc".getBytes()), + box(tag(3, WT_VARINT)), new Byte[]{0x01}); + Byte[] middleMessage = concat( + box(tag(1, WT_LEN)), box(encodeVarint(innerMessage.length)), innerMessage, + box(tag(2, WT_VARINT)), box(encodeVarint(123L))); + Byte[] row = concat( + box(tag(1, WT_LEN)), box(encodeVarint(middleMessage.length)), middleMessage, + box(tag(2, WT_32BIT)), box(encodeFloat(1.25f))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedA = ColumnVector.fromBoxedInts(7); + ColumnVector expectedB = ColumnVector.fromStrings("abc"); + ColumnVector expectedC = ColumnVector.fromBoxedBooleans(true); + ColumnVector expectedInner = ColumnVector.makeStruct(expectedA, expectedB, expectedC); + ColumnVector expectedM = ColumnVector.fromBoxedLongs(123L); + ColumnVector expectedMiddle = ColumnVector.makeStruct(expectedInner, expectedM); + ColumnVector expectedScore = ColumnVector.fromBoxedFloats(1.25f); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedMiddle, expectedScore); + ColumnVector actualStruct = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1, 1, 1, 2, 3, 2, 2}, // fieldNumbers + new int[]{-1, 0, 1, 1, 1, 0, -1}, // parentIndices + new int[]{0, 1, 2, 2, 2, 1, 0}, // depthLevels + new int[]{Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_VARINT, Protobuf.WT_LEN, + Protobuf.WT_VARINT, Protobuf.WT_VARINT, Protobuf.WT_32BIT}, // wireTypes + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.STRUCT.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId(), + 
DType.BOOL8.getTypeId().getNativeId(), DType.INT64.getTypeId().getNativeId(), + DType.FLOAT32.getTypeId().getNativeId()}, // outputTypeIds + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT}, // encodings + new boolean[]{false, false, false, false, false, false, false}, // isRepeated + new boolean[]{false, false, false, false, false, false, false}, // isRequired + new boolean[]{false, false, false, false, false, false, false}, // hasDefaultValue + new long[]{0, 0, 0, 0, 0, 0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false, false, false, false}, + new byte[][]{null, null, null, null, null, null, null}, + new int[][]{null, null, null, null, null, null, null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + // ============================================================================ // FAILFAST Mode Tests (failOnErrors = true) // ============================================================================ @@ -1929,6 +1981,125 @@ private static ColumnVector decodeAllFieldsWithEnums(ColumnView binaryInput, enumValidValues, failOnErrors); } + /** + * Helper that enables enum-as-string decoding by passing enum name mappings. 
+ */ + private static ColumnVector decodeAllFieldsWithEnumStrings(ColumnView binaryInput, + int[] fieldNumbers, + int[][] enumValidValues, + byte[][][] enumNames, + boolean failOnErrors) { + int numFields = fieldNumbers.length; + int[] typeIds = new int[numFields]; + int[] encodings = new int[numFields]; + for (int i = 0; i < numFields; i++) { + typeIds[i] = DType.STRING.getTypeId().getNativeId(); + encodings[i] = Protobuf.ENC_ENUM_STRING; + } + int[] parentIndices = new int[numFields]; + int[] depthLevels = new int[numFields]; + int[] wireTypes = new int[numFields]; + boolean[] isRepeated = new boolean[numFields]; + java.util.Arrays.fill(parentIndices, -1); + java.util.Arrays.fill(wireTypes, Protobuf.WT_VARINT); + return Protobuf.decodeToStruct( + binaryInput, + fieldNumbers, + parentIndices, + depthLevels, + wireTypes, + typeIds, + encodings, + isRepeated, + new boolean[numFields], + new boolean[numFields], + new long[numFields], + new double[numFields], + new boolean[numFields], + new byte[numFields][], + enumValidValues, + enumNames, + failOnErrors); + } + + @Test + void testEnumAsStringValidValue() { + // enum Color { RED=0; GREEN=1; BLUE=2; } + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(1))); // GREEN + + byte[][][] enumNames = new byte[][][] { + new byte[][] { + "RED".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "GREEN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BLUE".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedField = ColumnVector.fromStrings("GREEN"); + ColumnVector expected = ColumnVector.makeStruct(expectedField); + ColumnVector actual = decodeAllFieldsWithEnumStrings( + input.getColumn(0), + new int[]{1}, + new int[][]{{0, 1, 2}}, + enumNames, + false)) { + AssertUtils.assertStructColumnsAreEqual(expected, actual); + } + } + + @Test + void testEnumAsStringUnknownValueReturnsNullRow() { + // Unknown 
enum value should null the entire struct row (PERMISSIVE behavior). + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(999))); + + byte[][][] enumNames = new byte[][][] { + new byte[][] { + "RED".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "GREEN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BLUE".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeAllFieldsWithEnumStrings( + input.getColumn(0), + new int[]{1}, + new int[][]{{0, 1, 2}}, + enumNames, + false); + HostColumnVector hostStruct = actual.copyToHost()) { + assert actual.getNullCount() == 1 : "Struct row should be null for unknown enum value"; + assert hostStruct.isNull(0) : "Row 0 should be null"; + } + } + + @Test + void testEnumAsStringMixedValidAndUnknown() { + Byte[] row0 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(0))); // RED + Byte[] row1 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(999))); // unknown + Byte[] row2 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(2))); // BLUE + + byte[][][] enumNames = new byte[][][] { + new byte[][] { + "RED".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "GREEN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BLUE".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + try (Table input = new Table.TestBuilder().column(row0, row1, row2).build(); + ColumnVector actual = decodeAllFieldsWithEnumStrings( + input.getColumn(0), + new int[]{1}, + new int[][]{{0, 1, 2}}, + enumNames, + false); + HostColumnVector hostStruct = actual.copyToHost()) { + assert actual.getNullCount() == 1 : "Only the unknown enum row should be null"; + assert !hostStruct.isNull(0) : "Row 0 should be valid"; + assert hostStruct.isNull(1) : "Row 1 should be null"; + assert !hostStruct.isNull(2) : "Row 2 should be valid"; + } + } + @Test void testEnumValidValue() { // enum Color { RED=0; GREEN=1; BLUE=2; } @@ -2070,4 
+2241,50 @@ void testEnumValidWithOtherFields() { AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); } } + + // ============================================================================ + // Repeated Enum-as-String Tests + // ============================================================================ + + @Test + void testRepeatedEnumAsString() { + // repeated Color colors = 1; with Color { RED=0; GREEN=1; BLUE=2; } + // Row with three occurrences: RED, BLUE, GREEN + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(0)), // RED + box(tag(1, WT_VARINT)), box(encodeVarint(2)), // BLUE + box(tag(1, WT_VARINT)), box(encodeVarint(1))); // GREEN + + byte[][][] enumNames = new byte[][][] { + new byte[][] { + "RED".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "GREEN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BLUE".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_VARINT}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ENUM_STRING}, + new boolean[]{true}, // isRepeated + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{{0, 1, 2}}, + enumNames, + false)) { + assertNotNull(actual); + assertEquals(DType.STRUCT, actual.getType()); + // The struct has 1 child: a LIST column with ["RED", "BLUE", "GREEN"] + assertEquals(1, actual.getNumChildren()); + } + } } From 7e0e77d00049fb7728ed1431f726b28f21174376 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 12 Feb 2026 16:03:32 +0800 Subject: [PATCH 022/107] add a tests Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 1 + .../nvidia/spark/rapids/jni/ProtobufTest.java | 108 ++++++++++++++++++ 2 files changed, 109 
insertions(+) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 841b78268c..0bb069a9f9 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -4223,6 +4223,7 @@ std::unique_ptr decode_protobuf_to_struct( field_info[row] = h_repeated_info[row * num_repeated + ri]; } + if (total_count > 0) { // Build offsets for occurrence scanning on GPU (performance fix!) rmm::device_uvector d_occ_offsets(num_rows + 1, stream, mr); diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 3d536445cf..51054c2f54 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -1687,6 +1687,114 @@ void testUnpackedRepeatedInt32() { } } + @Test + void testPackedRepeatedDoubleWithMultipleFields() { + // Test packed repeated fields with multiple types including edge cases. + // message WithPackedRepeated { + // optional int32 id = 1; + // repeated int32 int_values = 2 [packed=true]; + // repeated double double_values = 3 [packed=true]; + // repeated bool bool_values = 4 [packed=true]; + // } + + // Helper to build packed int data (varints) + java.io.ByteArrayOutputStream intBuf = new java.io.ByteArrayOutputStream(); + + // Row 0: id=42, int_values=[1,-1,100] (12 bytes packed), double_values=[1.5,2.5], bool=[true,false] + // Row 1: id=7, int_values=15x(-1) (150 bytes packed, 2-byte length varint!), double_values=[3.0,4.0], bool=[true] + // Row 2: id=0, int_values=[] (field omitted), double_values=[5.0], bool=[] (field omitted) + + // --- Row 0 --- + byte[] r0IntVarints = concatBytes(encodeVarint(1), encodeVarint(-1L & 0xFFFFFFFFFFFFFFFFL), encodeVarint(100)); + byte[] r0Doubles = concatBytes(encodeDouble(1.5), encodeDouble(2.5)); + byte[] r0Bools = new byte[]{0x01, 0x00}; + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(42)), + box(tag(2, WT_LEN)), 
box(encodeVarint(r0IntVarints.length)), box(r0IntVarints), + box(tag(3, WT_LEN)), box(encodeVarint(r0Doubles.length)), box(r0Doubles), + box(tag(4, WT_LEN)), box(encodeVarint(r0Bools.length)), box(r0Bools)); + + // --- Row 1: 15 negative ints => 150 bytes packed (length varint is 2 bytes: 0x96 0x01) --- + java.io.ByteArrayOutputStream buf1 = new java.io.ByteArrayOutputStream(); + byte[] negOneVarint = encodeVarint(-1L & 0xFFFFFFFFFFFFFFFFL); // 10 bytes + for (int i = 0; i < 15; i++) { + buf1.write(negOneVarint, 0, negOneVarint.length); + } + byte[] r1IntVarints = buf1.toByteArray(); // 150 bytes + byte[] r1Doubles = concatBytes(encodeDouble(3.0), encodeDouble(4.0)); + byte[] r1Bools = new byte[]{0x01}; + Byte[] row1 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(7)), + box(tag(2, WT_LEN)), box(encodeVarint(r1IntVarints.length)), box(r1IntVarints), + box(tag(3, WT_LEN)), box(encodeVarint(r1Doubles.length)), box(r1Doubles), + box(tag(4, WT_LEN)), box(encodeVarint(r1Bools.length)), box(r1Bools)); + + // --- Row 2: no int_values, no bool_values --- + byte[] r2Doubles = encodeDouble(5.0); + Byte[] row2 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(0)), + box(tag(3, WT_LEN)), box(encodeVarint(r2Doubles.length)), box(r2Doubles)); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row0, row1, row2}).build()) { + try (ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1, 2, 3, 4}, + new int[]{-1, -1, -1, -1}, + new int[]{0, 0, 0, 0}, + new int[]{WT_VARINT, WT_VARINT, WT_64BIT, WT_VARINT}, + new int[]{ + DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.FLOAT64.getTypeId().getNativeId(), + DType.BOOL8.getTypeId().getNativeId() + }, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, true, true, true}, + new boolean[]{false, false, false, false}, + new boolean[]{false, false, false, false}, + new long[]{0, 0, 
0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false}, + new byte[][]{null, null, null, null}, + new int[][]{null, null, null, null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + assertEquals(3, result.getRowCount()); + + // Verify double_values child column has correct total count: 2 + 2 + 1 = 5 + try (ColumnVector doubleListCol = result.getChildColumnView(2).copyToColumnVector()) { + assertEquals(DType.LIST, doubleListCol.getType()); + try (ColumnVector doubleChildren = doubleListCol.getChildColumnView(0).copyToColumnVector()) { + assertEquals(DType.FLOAT64, doubleChildren.getType()); + assertEquals(5, doubleChildren.getRowCount(), + "Total packed doubles across 3 rows should be 5, got " + doubleChildren.getRowCount()); + try (HostColumnVector hd = doubleChildren.copyToHost()) { + assertEquals(1.5, hd.getDouble(0), 1e-10); + assertEquals(2.5, hd.getDouble(1), 1e-10); + assertEquals(3.0, hd.getDouble(2), 1e-10); + assertEquals(4.0, hd.getDouble(3), 1e-10); + assertEquals(5.0, hd.getDouble(4), 1e-10); + } + } + } + } + } + } + + /** Helper: concatenate byte arrays */ + private static byte[] concatBytes(byte[]... 
arrays) { + int len = 0; + for (byte[] a : arrays) len += a.length; + byte[] out = new byte[len]; + int pos = 0; + for (byte[] a : arrays) { + System.arraycopy(a, 0, out, pos, a.length); + pos += a.length; + } + return out; + } + @Test void testNestedMessage() { // message Inner { int32 x = 1; } From ac39a4e533a59abdadd707bd228b890fcbfd0bdf Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 25 Feb 2026 11:31:18 +0800 Subject: [PATCH 023/107] Kernal code clean up Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 690 +++++++++++++++++------------------ 1 file changed, 337 insertions(+), 353 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 0bb069a9f9..df6c1c6b36 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -40,22 +40,35 @@ #include #include -#include #include #include namespace { -// Wire type constants +// Wire type constants (protobuf encoding spec) constexpr int WT_VARINT = 0; constexpr int WT_64BIT = 1; constexpr int WT_LEN = 2; constexpr int WT_32BIT = 5; -} // namespace +// Protobuf varint encoding uses at most 10 bytes to represent a 64-bit value. +constexpr int MAX_VARINT_BYTES = 10; -namespace { +// CUDA kernel launch configuration. +constexpr int THREADS_PER_BLOCK = 256; + +// Error codes for kernel error reporting. +constexpr int ERR_BOUNDS = 1; +constexpr int ERR_VARINT = 2; +constexpr int ERR_FIELD_NUMBER = 3; +constexpr int ERR_WIRE_TYPE = 4; +constexpr int ERR_OVERFLOW = 5; +constexpr int ERR_FIELD_SIZE = 6; +constexpr int ERR_SKIP = 7; +constexpr int ERR_FIXED_LEN = 8; +constexpr int ERR_REQUIRED = 9; +// Maximum supported nesting depth for recursive struct decoding. constexpr int MAX_NESTED_STRUCT_DECODE_DEPTH = 10; /** @@ -119,7 +132,9 @@ __device__ inline bool read_varint(uint8_t const* cur, out = 0; bytes = 0; int shift = 0; - while (cur < end && bytes < 10) { + // Protobuf varint uses 7 bits per byte with MSB as continuation flag. 
+ // A 64-bit value requires at most ceil(64/7) = 10 bytes. + while (cur < end && bytes < MAX_VARINT_BYTES) { uint8_t b = *cur++; // For the 10th byte (bytes == 9, shift == 63), only the lowest bit is valid if (bytes == 9 && (b & 0xFE) != 0) { @@ -139,7 +154,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con case WT_VARINT: { // Need to scan to find the end of varint int count = 0; - while (cur < end && count < 10) { + while (cur < end && count < MAX_VARINT_BYTES) { if ((*cur++ & 0x80u) == 0) { return count + 1; } count++; } @@ -157,7 +172,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con uint64_t len; int n; if (!read_varint(cur, end, len, n)) return -1; - if (len > static_cast(end - cur - n) || len > static_cast(INT_MAX)) + if (len > static_cast(end - cur - n) || len > static_cast(INT_MAX - n)) return -1; return n + static_cast(len); } @@ -209,6 +224,49 @@ __device__ inline bool get_field_data_location(uint8_t const* cur, return true; } +__device__ inline bool check_message_bounds(int32_t start, + int32_t end_pos, + cudf::size_type total_size, + int* error_flag) +{ + if (start < 0 || end_pos < start || end_pos > total_size) { + atomicExch(error_flag, ERR_BOUNDS); + return false; + } + return true; +} + +struct proto_tag { + int field_number; + int wire_type; +}; + +__device__ inline bool decode_tag(uint8_t const*& cur, + uint8_t const* end, + proto_tag& tag, + int* error_flag) +{ + uint64_t key; + int key_bytes; + if (!read_varint(cur, end, key, key_bytes)) { + atomicExch(error_flag, ERR_VARINT); + return false; + } + + cur += key_bytes; + tag.field_number = static_cast(key >> 3); + tag.wire_type = static_cast(key & 0x7); + if (tag.field_number == 0) { + atomicExch(error_flag, ERR_FIELD_NUMBER); + return false; + } + return true; +} + +/** + * Load a little-endian value from unaligned memory. + * Reads bytes individually to avoid unaligned-access issues on GPU. 
+ */ template __device__ inline T load_le(uint8_t const* p); @@ -240,6 +298,8 @@ __device__ inline uint64_t load_le(uint8_t const* p) * * For "last one wins" semantics (protobuf standard for repeated scalars), * we continue scanning even after finding a field. + * + * @note Time complexity: O(message_length * num_fields) per row. */ __global__ void scan_all_fields_kernel( cudf::column_device_view const d_in, @@ -264,42 +324,27 @@ __global__ void scan_all_fields_kernel( auto const base = in.offset_at(0); auto const child = in.get_sliced_child(); auto const* bytes = reinterpret_cast(child.data()); - auto start = in.offset_at(row) - base; - auto end = in.offset_at(row + 1) - base; + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; - // Bounds check - if (start < 0 || end < start || end > child.size()) { - atomicExch(error_flag, 1); - return; - } + if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } - uint8_t const* cur = bytes + start; - uint8_t const* stop = bytes + end; + uint8_t const* cur = bytes + start; + uint8_t const* msg_end = bytes + end; // Scan the message once, recording locations of all target fields - while (cur < stop) { - uint64_t key; - int key_bytes; - if (!read_varint(cur, stop, key, key_bytes)) { - atomicExch(error_flag, 1); - return; - } - cur += key_bytes; - - int fn = static_cast(key >> 3); - int wt = static_cast(key & 0x7); - - if (fn == 0) { - atomicExch(error_flag, 1); - return; - } + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; // Check if this field is one we're looking for for (int f = 0; f < num_fields; f++) { if (field_descs[f].field_number == fn) { // Check wire type matches if (wt != field_descs[f].expected_wire_type) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_WIRE_TYPE); return; } @@ -310,22 +355,22 @@ __global__ void scan_all_fields_kernel( 
// For length-delimited, record offset after length prefix and the data length uint64_t len; int len_bytes; - if (!read_varint(cur, stop, len, len_bytes)) { - atomicExch(error_flag, 1); + if (!read_varint(cur, msg_end, len, len_bytes)) { + atomicExch(error_flag, ERR_VARINT); return; } - if (len > static_cast(stop - cur - len_bytes) || + if (len > static_cast(msg_end - cur - len_bytes) || len > static_cast(INT_MAX)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_OVERFLOW); return; } // Record offset pointing to the actual data (after length prefix) locations[row * num_fields + f] = {data_offset + len_bytes, static_cast(len)}; } else { // For fixed-size and varint fields, record offset and compute length - int field_size = get_wire_type_size(wt, cur, stop); + int field_size = get_wire_type_size(wt, cur, msg_end); if (field_size < 0) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_FIELD_SIZE); return; } locations[row * num_fields + f] = {data_offset, field_size}; @@ -336,8 +381,8 @@ __global__ void scan_all_fields_kernel( // Skip to next field uint8_t const* next; - if (!skip_field(cur, stop, wt, next)) { - atomicExch(error_flag, 1); + if (!skip_field(cur, msg_end, wt, next)) { + atomicExch(error_flag, ERR_SKIP); return; } cur = next; @@ -351,6 +396,8 @@ __global__ void scan_all_fields_kernel( /** * Count occurrences of repeated fields in each row. * Also records locations of nested message fields for hierarchical processing. + * + * @note Time complexity: O(message_length * (num_repeated_fields + num_nested_fields)) per row. 
*/ __global__ void count_repeated_fields_kernel( cudf::column_device_view const d_in, @@ -386,33 +433,18 @@ __global__ void count_repeated_fields_kernel( auto const base = in.offset_at(0); auto const child = in.get_sliced_child(); auto const* bytes = reinterpret_cast(child.data()); - auto start = in.offset_at(row) - base; - auto end = in.offset_at(row + 1) - base; + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; + if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } - if (start < 0 || end < start || end > child.size()) { - atomicExch(error_flag, 1); - return; - } - - uint8_t const* cur = bytes + start; - uint8_t const* stop = bytes + end; + uint8_t const* cur = bytes + start; + uint8_t const* msg_end = bytes + end; - while (cur < stop) { - uint64_t key; - int key_bytes; - if (!read_varint(cur, stop, key, key_bytes)) { - atomicExch(error_flag, 1); - return; - } - cur += key_bytes; - - int fn = static_cast(key >> 3); - int wt = static_cast(key & 0x7); - - if (fn == 0) { - atomicExch(error_flag, 1); - return; - } + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; // Check repeated fields at this depth for (int i = 0; i < num_repeated_fields; i++) { @@ -425,7 +457,7 @@ __global__ void count_repeated_fields_kernel( bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); if (!is_packed && wt != expected_wt) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_WIRE_TYPE); return; } @@ -433,16 +465,16 @@ __global__ void count_repeated_fields_kernel( // Packed encoding: read length, then count elements inside uint64_t packed_len; int len_bytes; - if (!read_varint(cur, stop, packed_len, len_bytes)) { - atomicExch(error_flag, 1); + if (!read_varint(cur, msg_end, packed_len, len_bytes)) { + atomicExch(error_flag, ERR_VARINT); return; } // Count elements based on type uint8_t const* packed_start = 
cur + len_bytes; uint8_t const* packed_end = packed_start + packed_len; - if (packed_end > stop) { - atomicExch(error_flag, 1); + if (packed_end > msg_end) { + atomicExch(error_flag, ERR_OVERFLOW); return; } @@ -454,7 +486,7 @@ __global__ void count_repeated_fields_kernel( uint64_t dummy; int vbytes; if (!read_varint(p, packed_end, dummy, vbytes)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_VARINT); return; } p += vbytes; @@ -471,8 +503,8 @@ __global__ void count_repeated_fields_kernel( } else { // Non-packed encoding: single element int32_t data_offset, data_length; - if (!get_field_data_location(cur, stop, wt, data_offset, data_length)) { - atomicExch(error_flag, 1); + if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { + atomicExch(error_flag, ERR_FIELD_SIZE); return; } @@ -487,14 +519,14 @@ __global__ void count_repeated_fields_kernel( int schema_idx = nested_field_indices[i]; if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { if (wt != WT_LEN) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_WIRE_TYPE); return; } uint64_t len; int len_bytes; - if (!read_varint(cur, stop, len, len_bytes)) { - atomicExch(error_flag, 1); + if (!read_varint(cur, msg_end, len, len_bytes)) { + atomicExch(error_flag, ERR_VARINT); return; } @@ -505,8 +537,8 @@ __global__ void count_repeated_fields_kernel( // Skip to next field uint8_t const* next; - if (!skip_field(cur, stop, wt, next)) { - atomicExch(error_flag, 1); + if (!skip_field(cur, msg_end, wt, next)) { + atomicExch(error_flag, ERR_SKIP); return; } cur = next; @@ -516,6 +548,8 @@ __global__ void count_repeated_fields_kernel( /** * Scan and record all occurrences of repeated fields. * Called after count_repeated_fields_kernel to fill in actual locations. + * + * @note Time complexity: O(message_length * num_repeated_fields) per row. 
*/ __global__ void scan_repeated_field_occurrences_kernel( cudf::column_device_view const d_in, @@ -537,37 +571,22 @@ __global__ void scan_repeated_field_occurrences_kernel( auto const base = in.offset_at(0); auto const child = in.get_sliced_child(); auto const* bytes = reinterpret_cast(child.data()); - auto start = in.offset_at(row) - base; - auto end = in.offset_at(row + 1) - base; + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; + if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } - if (start < 0 || end < start || end > child.size()) { - atomicExch(error_flag, 1); - return; - } - - uint8_t const* cur = bytes + start; - uint8_t const* stop = bytes + end; + uint8_t const* cur = bytes + start; + uint8_t const* msg_end = bytes + end; int target_fn = schema[schema_idx].field_number; int target_wt = schema[schema_idx].wire_type; int write_idx = output_offsets[row]; - while (cur < stop) { - uint64_t key; - int key_bytes; - if (!read_varint(cur, stop, key, key_bytes)) { - atomicExch(error_flag, 1); - return; - } - cur += key_bytes; - - int fn = static_cast(key >> 3); - int wt = static_cast(key & 0x7); - - if (fn == 0) { - atomicExch(error_flag, 1); - return; - } + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; if (fn == target_fn) { // Check for packed encoding: wire type LEN but expected non-LEN @@ -577,15 +596,15 @@ __global__ void scan_repeated_field_occurrences_kernel( // Packed encoding: multiple elements in a length-delimited blob uint64_t packed_len; int len_bytes; - if (!read_varint(cur, stop, packed_len, len_bytes)) { - atomicExch(error_flag, 1); + if (!read_varint(cur, msg_end, packed_len, len_bytes)) { + atomicExch(error_flag, ERR_VARINT); return; } uint8_t const* packed_start = cur + len_bytes; uint8_t const* packed_end = packed_start + packed_len; - if (packed_end > stop) { - 
atomicExch(error_flag, 1); + if (packed_end > msg_end) { + atomicExch(error_flag, ERR_OVERFLOW); return; } @@ -598,7 +617,7 @@ __global__ void scan_repeated_field_occurrences_kernel( uint64_t dummy; int vbytes; if (!read_varint(p, packed_end, dummy, vbytes)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_VARINT); return; } occurrences[write_idx] = {static_cast(row), elem_offset, vbytes}; @@ -627,8 +646,8 @@ __global__ void scan_repeated_field_occurrences_kernel( } else if (wt == target_wt) { // Non-packed encoding: single element int32_t data_offset, data_length; - if (!get_field_data_location(cur, stop, wt, data_offset, data_length)) { - atomicExch(error_flag, 1); + if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { + atomicExch(error_flag, ERR_FIELD_SIZE); return; } @@ -640,8 +659,8 @@ __global__ void scan_repeated_field_occurrences_kernel( // Skip to next field uint8_t const* next; - if (!skip_field(cur, stop, wt, next)) { - atomicExch(error_flag, 1); + if (!skip_field(cur, msg_end, wt, next)) { + atomicExch(error_flag, ERR_SKIP); return; } cur = next; @@ -656,7 +675,7 @@ __global__ void scan_repeated_field_occurrences_kernel( * Extract varint field data using pre-recorded locations. * Supports default values for missing fields. 
*/ -template +template __global__ void extract_varint_from_locations_kernel( uint8_t const* message_data, cudf::size_type const* offsets, // List offsets for each row @@ -664,7 +683,7 @@ __global__ void extract_varint_from_locations_kernel( field_location const* locations, // [num_rows * num_fields] int field_idx, int num_fields, - OutT* out, + OutputType* out, bool* valid, int num_rows, int* error_flag, @@ -678,7 +697,7 @@ __global__ void extract_varint_from_locations_kernel( if (loc.offset < 0) { // Field not found - use default value if available if (has_default) { - out[row] = static_cast(default_value); + out[row] = static_cast(default_value); valid[row] = true; } else { valid[row] = false; @@ -694,13 +713,13 @@ __global__ void extract_varint_from_locations_kernel( uint64_t v; int n; if (!read_varint(cur, cur_end, v, n)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_VARINT); valid[row] = false; return; } if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } - out[row] = static_cast(v); + out[row] = static_cast(v); valid[row] = true; } @@ -708,19 +727,19 @@ __global__ void extract_varint_from_locations_kernel( * Extract fixed-size field data (fixed32, fixed64, float, double). * Supports default values for missing fields. 
*/ -template +template __global__ void extract_fixed_from_locations_kernel(uint8_t const* message_data, cudf::size_type const* offsets, cudf::size_type base_offset, field_location const* locations, int field_idx, int num_fields, - OutT* out, + OutputType* out, bool* valid, int num_rows, int* error_flag, bool has_default = false, - OutT default_value = OutT{}) + OutputType default_value = OutputType{}) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; @@ -740,10 +759,10 @@ __global__ void extract_fixed_from_locations_kernel(uint8_t const* message_data, auto row_start = offsets[row] - base_offset; uint8_t const* cur = message_data + row_start + loc.offset; - OutT value; + OutputType value; if constexpr (WT == WT_32BIT) { if (loc.length < 4) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_FIXED_LEN); valid[row] = false; return; } @@ -751,7 +770,7 @@ __global__ void extract_fixed_from_locations_kernel(uint8_t const* message_data, memcpy(&value, &raw, sizeof(value)); } else { if (loc.length < 8) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_FIXED_LEN); valid[row] = false; return; } @@ -770,14 +789,14 @@ __global__ void extract_fixed_from_locations_kernel(uint8_t const* message_data, /** * Extract repeated varint values using pre-recorded occurrences. 
*/ -template +template __global__ void extract_repeated_varint_kernel( uint8_t const* message_data, cudf::size_type const* row_offsets, cudf::size_type base_offset, repeated_occurrence const* occurrences, int total_occurrences, - OutT* out, + OutputType* out, int* error_flag) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); @@ -791,26 +810,26 @@ __global__ void extract_repeated_varint_kernel( uint64_t v; int n; if (!read_varint(cur, cur_end, v, n)) { - atomicExch(error_flag, 1); - out[idx] = OutT{}; + atomicExch(error_flag, ERR_VARINT); + out[idx] = OutputType{}; return; } if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } - out[idx] = static_cast(v); + out[idx] = static_cast(v); } /** * Extract repeated fixed-size values using pre-recorded occurrences. */ -template +template __global__ void extract_repeated_fixed_kernel( uint8_t const* message_data, cudf::size_type const* row_offsets, cudf::size_type base_offset, repeated_occurrence const* occurrences, int total_occurrences, - OutT* out, + OutputType* out, int* error_flag) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); @@ -820,19 +839,19 @@ __global__ void extract_repeated_fixed_kernel( auto row_start = row_offsets[occ.row_idx] - base_offset; uint8_t const* cur = message_data + row_start + occ.offset; - OutT value; + OutputType value; if constexpr (WT == WT_32BIT) { if (occ.length < 4) { - atomicExch(error_flag, 1); - out[idx] = OutT{}; + atomicExch(error_flag, ERR_FIXED_LEN); + out[idx] = OutputType{}; return; } uint32_t raw = load_le(cur); memcpy(&value, &raw, sizeof(value)); } else { if (occ.length < 8) { - atomicExch(error_flag, 1); - out[idx] = OutT{}; + atomicExch(error_flag, ERR_FIXED_LEN); + out[idx] = OutputType{}; return; } uint64_t raw = load_le(cur); @@ -864,9 +883,7 @@ __global__ void copy_repeated_varlen_data_kernel( uint8_t const* src = message_data + row_start + occ.offset; char* dst = output_data + output_offsets[idx]; - for (int i = 0; i < occ.length; i++) { - 
dst[i] = static_cast(src[i]); - } + memcpy(dst, src, occ.length); } /** @@ -922,26 +939,15 @@ __global__ void scan_nested_message_fields_kernel( uint8_t const* cur = nested_start; while (cur < nested_end) { - uint64_t key; - int key_bytes; - if (!read_varint(cur, nested_end, key, key_bytes)) { - atomicExch(error_flag, 1); - return; - } - cur += key_bytes; - - int fn = static_cast(key >> 3); - int wt = static_cast(key & 0x7); - - if (fn == 0) { - atomicExch(error_flag, 1); - return; - } + proto_tag tag; + if (!decode_tag(cur, nested_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; for (int f = 0; f < num_fields; f++) { if (field_descs[f].field_number == fn) { if (wt != field_descs[f].expected_wire_type) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_WIRE_TYPE); return; } @@ -951,19 +957,19 @@ __global__ void scan_nested_message_fields_kernel( uint64_t len; int len_bytes; if (!read_varint(cur, nested_end, len, len_bytes)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_VARINT); return; } if (len > static_cast(nested_end - cur - len_bytes) || len > static_cast(INT_MAX)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_OVERFLOW); return; } output_locations[row * num_fields + f] = {data_offset + len_bytes, static_cast(len)}; } else { int field_size = get_wire_type_size(wt, cur, nested_end); if (field_size < 0) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_FIELD_SIZE); return; } output_locations[row * num_fields + f] = {data_offset, field_size}; @@ -973,15 +979,18 @@ __global__ void scan_nested_message_fields_kernel( uint8_t const* next; if (!skip_field(cur, nested_end, wt, next)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_SKIP); return; } cur = next; } } -// Utility function: make_null_mask_from_valid -// (Moved here to be available for repeated message child extraction) +/** + * Build a null bitmask from a boolean validity array. 
+ * @param valid Device vector where valid[i] indicates row i validity. + * @return Pair of (null mask buffer, null count). + */ template inline std::pair make_null_mask_from_valid( rmm::device_uvector const& valid, @@ -1027,21 +1036,10 @@ __global__ void scan_repeated_message_children_kernel( uint8_t const* cur = msg_start; while (cur < msg_end) { - uint64_t key; - int key_bytes; - if (!read_varint(cur, msg_end, key, key_bytes)) { - atomicExch(error_flag, 1); - return; - } - cur += key_bytes; - - int fn = static_cast(key >> 3); - int wt = static_cast(key & 0x7); - - if (fn == 0) { - atomicExch(error_flag, 1); - return; - } + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; // Check against child field descriptors for (int f = 0; f < num_child_fields; f++) { @@ -1058,7 +1056,7 @@ __global__ void scan_repeated_message_children_kernel( uint64_t len; int len_bytes; if (!read_varint(cur, msg_end, len, len_bytes)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_VARINT); return; } // Store offset (after length prefix) and length @@ -1086,7 +1084,7 @@ __global__ void scan_repeated_message_children_kernel( // Skip to next field uint8_t const* next; if (!skip_field(cur, msg_end, wt, next)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_SKIP); return; } cur = next; @@ -1110,35 +1108,29 @@ __global__ void count_repeated_in_nested_kernel( int const* repeated_indices, int* error_flag) { - auto row_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row_idx >= num_rows) return; + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; // Initialize counts for (int ri = 0; ri < num_repeated; ri++) { - repeated_info[row_idx * num_repeated + ri] = {0, 0}; + repeated_info[row * num_repeated + ri] = {0, 0}; } - auto const& parent_loc = parent_locs[row_idx]; + auto const& parent_loc = parent_locs[row]; if (parent_loc.offset < 0) 
return; cudf::size_type row_off; - row_off = row_offsets[row_idx] - base_offset; + row_off = row_offsets[row] - base_offset; uint8_t const* msg_start = message_data + row_off + parent_loc.offset; uint8_t const* msg_end = msg_start + parent_loc.length; uint8_t const* cur = msg_start; while (cur < msg_end) { - uint64_t key; - int key_bytes; - if (!read_varint(cur, msg_end, key, key_bytes)) { - atomicExch(error_flag, 1); - return; - } - cur += key_bytes; - - int fn = static_cast(key >> 3); - int wt = static_cast(key & 0x7); + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; // Check if this is one of our repeated fields for (int ri = 0; ri < num_repeated; ri++) { @@ -1149,19 +1141,19 @@ __global__ void count_repeated_in_nested_kernel( uint64_t len; int len_bytes; if (!read_varint(cur, msg_end, len, len_bytes)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_VARINT); return; } data_len = static_cast(len); } - repeated_info[row_idx * num_repeated + ri].count++; - repeated_info[row_idx * num_repeated + ri].total_length += data_len; + repeated_info[row * num_repeated + ri].count++; + repeated_info[row * num_repeated + ri].total_length += data_len; } } uint8_t const* next; if (!skip_field(cur, msg_end, wt, next)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_SKIP); return; } cur = next; @@ -1179,25 +1171,22 @@ __global__ void scan_repeated_in_nested_kernel( int num_rows, device_nested_field_descriptor const* schema, int num_fields, - repeated_field_info const* repeated_info, + int32_t const* occ_prefix_sums, int num_repeated, int const* repeated_indices, repeated_occurrence* occurrences, int* error_flag) { - auto row_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row_idx >= num_rows) return; + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; - auto const& parent_loc = parent_locs[row_idx]; + auto const& 
parent_loc = parent_locs[row]; if (parent_loc.offset < 0) return; - // Calculate output offset for this row - int occ_offset = 0; - for (int r = 0; r < row_idx; r++) { - occ_offset += repeated_info[r * num_repeated].count; - } + // Prefix sum gives the write start offset for this row. + int occ_offset = occ_prefix_sums[row]; - cudf::size_type row_off = row_offsets[row_idx] - base_offset; + cudf::size_type row_off = row_offsets[row] - base_offset; uint8_t const* msg_start = message_data + row_off + parent_loc.offset; uint8_t const* msg_end = msg_start + parent_loc.length; @@ -1206,51 +1195,47 @@ __global__ void scan_repeated_in_nested_kernel( int occ_idx = 0; while (cur < msg_end) { - uint64_t key; - int key_bytes; - if (!read_varint(cur, msg_end, key, key_bytes)) { - atomicExch(error_flag, 1); - return; - } - cur += key_bytes; + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; - int fn = static_cast(key >> 3); - int wt = static_cast(key & 0x7); + // Check if this is one of our repeated fields. 
+ for (int ri = 0; ri < num_repeated; ri++) { + int schema_idx = repeated_indices[ri]; + if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { + int32_t data_offset = static_cast(cur - msg_start); + int32_t data_len = 0; - // Check if this is our repeated field (assuming single repeated field for simplicity) - int schema_idx = repeated_indices[0]; - if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { - int32_t data_offset = static_cast(cur - msg_start); - int32_t data_len = 0; - - if (wt == WT_LEN) { - uint64_t len; - int len_bytes; - if (!read_varint(cur, msg_end, len, len_bytes)) { - atomicExch(error_flag, 1); - return; - } - data_offset += len_bytes; - data_len = static_cast(len); - } else if (wt == WT_VARINT) { - uint64_t dummy; - int vbytes; - if (read_varint(cur, msg_end, dummy, vbytes)) { - data_len = vbytes; + if (wt == WT_LEN) { + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + atomicExch(error_flag, ERR_VARINT); + return; + } + data_offset += len_bytes; + data_len = static_cast(len); + } else if (wt == WT_VARINT) { + uint64_t dummy; + int vbytes; + if (read_varint(cur, msg_end, dummy, vbytes)) { + data_len = vbytes; + } + } else if (wt == WT_32BIT) { + data_len = 4; + } else if (wt == WT_64BIT) { + data_len = 8; } - } else if (wt == WT_32BIT) { - data_len = 4; - } else if (wt == WT_64BIT) { - data_len = 8; + + occurrences[occ_offset + occ_idx] = {row, data_offset, data_len}; + occ_idx++; } - - occurrences[occ_offset + occ_idx] = {row_idx, data_offset, data_len}; - occ_idx++; } uint8_t const* next; if (!skip_field(cur, msg_end, wt, next)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_SKIP); return; } cur = next; @@ -1260,7 +1245,7 @@ __global__ void scan_repeated_in_nested_kernel( /** * Extract varint values from repeated field occurrences within nested messages. 
*/ -template +template __global__ void extract_repeated_in_nested_varint_kernel( uint8_t const* message_data, cudf::size_type const* row_offsets, @@ -1268,7 +1253,7 @@ __global__ void extract_repeated_in_nested_varint_kernel( field_location const* parent_locs, repeated_occurrence const* occurrences, int total_count, - OutT* out, + OutputType* out, int* error_flag) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); @@ -1279,22 +1264,23 @@ __global__ void extract_repeated_in_nested_varint_kernel( cudf::size_type row_off = row_offsets[occ.row_idx] - base_offset; uint8_t const* data_ptr = message_data + row_off + parent_loc.offset + occ.offset; + uint8_t const* msg_end = message_data + row_off + parent_loc.offset + parent_loc.length; + uint8_t const* varint_end = + (data_ptr + MAX_VARINT_BYTES < msg_end) ? (data_ptr + MAX_VARINT_BYTES) : msg_end; uint64_t val; int vbytes; - if (!read_varint(data_ptr, data_ptr + 10, val, vbytes)) { - atomicExch(error_flag, 1); + if (!read_varint(data_ptr, varint_end, val, vbytes)) { + atomicExch(error_flag, ERR_VARINT); return; } - if constexpr (ZigZag) { - val = (val >> 1) ^ (~(val & 1) + 1); - } + if constexpr (ZigZag) { val = (val >> 1) ^ (-(val & 1)); } - out[idx] = static_cast(val); + out[idx] = static_cast(val); } -template +template __global__ void extract_repeated_in_nested_fixed_kernel( uint8_t const* message_data, cudf::size_type const* row_offsets, @@ -1302,7 +1288,7 @@ __global__ void extract_repeated_in_nested_fixed_kernel( field_location const* parent_locs, repeated_occurrence const* occurrences, int total_count, - OutT* out, + OutputType* out, int* error_flag) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); @@ -1316,20 +1302,20 @@ __global__ void extract_repeated_in_nested_fixed_kernel( if constexpr (WT == WT_32BIT) { if (occ.length < 4) { - atomicExch(error_flag, 1); - out[idx] = OutT{}; + atomicExch(error_flag, ERR_FIXED_LEN); + out[idx] = OutputType{}; return; } uint32_t raw = 
load_le(data_ptr); - memcpy(&out[idx], &raw, sizeof(OutT)); + memcpy(&out[idx], &raw, sizeof(OutputType)); } else { if (occ.length < 8) { - atomicExch(error_flag, 1); - out[idx] = OutT{}; + atomicExch(error_flag, ERR_FIXED_LEN); + out[idx] = OutputType{}; return; } uint64_t raw = load_le(data_ptr); - memcpy(&out[idx], &raw, sizeof(OutT)); + memcpy(&out[idx], &raw, sizeof(OutputType)); } } @@ -1357,15 +1343,13 @@ __global__ void extract_repeated_in_nested_string_kernel( uint8_t const* data_ptr = message_data + row_off + parent_loc.offset + occ.offset; int32_t out_offset = str_offsets[idx]; - for (int32_t i = 0; i < occ.length; i++) { - chars[out_offset + i] = static_cast(data_ptr[i]); - } + memcpy(chars + out_offset, data_ptr, occ.length); } /** * Extract varint child fields from repeated message occurrences. */ -template +template __global__ void extract_repeated_msg_child_varint_kernel( uint8_t const* message_data, int32_t const* msg_row_offsets, @@ -1373,53 +1357,54 @@ __global__ void extract_repeated_msg_child_varint_kernel( field_location const* child_locs, int child_idx, int num_child_fields, - OutT* out, + OutputType* out, bool* valid, int num_occurrences, int* error_flag, bool has_default = false, int64_t default_value = 0) { - auto occ_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (occ_idx >= num_occurrences) return; + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= num_occurrences) return; - auto const& msg_loc = msg_locs[occ_idx]; - auto const& field_loc = child_locs[occ_idx * num_child_fields + child_idx]; + auto const& msg_loc = msg_locs[idx]; + auto const& field_loc = child_locs[idx * num_child_fields + child_idx]; if (msg_loc.offset < 0 || field_loc.offset < 0) { if (has_default) { - out[occ_idx] = static_cast(default_value); - valid[occ_idx] = true; + out[idx] = static_cast(default_value); + valid[idx] = true; } else { - valid[occ_idx] = false; + valid[idx] = false; } return; } - int32_t row_offset = 
msg_row_offsets[occ_idx]; + int32_t row_offset = msg_row_offsets[idx]; uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; uint8_t const* cur = msg_start + field_loc.offset; + uint8_t const* msg_end = msg_start + msg_loc.length; + uint8_t const* varint_end = + (cur + MAX_VARINT_BYTES < msg_end) ? (cur + MAX_VARINT_BYTES) : msg_end; uint64_t val; int vbytes; - if (!read_varint(cur, cur + 10, val, vbytes)) { - atomicExch(error_flag, 1); - valid[occ_idx] = false; + if (!read_varint(cur, varint_end, val, vbytes)) { + atomicExch(error_flag, ERR_VARINT); + valid[idx] = false; return; } - if constexpr (ZigZag) { - val = (val >> 1) ^ (~(val & 1) + 1); - } + if constexpr (ZigZag) { val = (val >> 1) ^ (-(val & 1)); } - out[occ_idx] = static_cast(val); - valid[occ_idx] = true; + out[idx] = static_cast(val); + valid[idx] = true; } /** * Extract fixed-size child fields from repeated message occurrences. */ -template +template __global__ void extract_repeated_msg_child_fixed_kernel( uint8_t const* message_data, int32_t const* msg_row_offsets, @@ -1427,34 +1412,34 @@ __global__ void extract_repeated_msg_child_fixed_kernel( field_location const* child_locs, int child_idx, int num_child_fields, - OutT* out, + OutputType* out, bool* valid, int num_occurrences, int* error_flag, bool has_default = false, - OutT default_value = OutT{}) + OutputType default_value = OutputType{}) { - auto occ_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (occ_idx >= num_occurrences) return; + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= num_occurrences) return; - auto const& msg_loc = msg_locs[occ_idx]; - auto const& field_loc = child_locs[occ_idx * num_child_fields + child_idx]; + auto const& msg_loc = msg_locs[idx]; + auto const& field_loc = child_locs[idx * num_child_fields + child_idx]; if (msg_loc.offset < 0 || field_loc.offset < 0) { if (has_default) { - out[occ_idx] = default_value; - valid[occ_idx] = true; + out[idx] = 
default_value; + valid[idx] = true; } else { - valid[occ_idx] = false; + valid[idx] = false; } return; } - int32_t row_offset = msg_row_offsets[occ_idx]; + int32_t row_offset = msg_row_offsets[idx]; uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; uint8_t const* cur = msg_start + field_loc.offset; - OutT value; + OutputType value; if constexpr (WT == WT_32BIT) { uint32_t raw = load_le(cur); memcpy(&value, &raw, sizeof(value)); @@ -1463,8 +1448,8 @@ __global__ void extract_repeated_msg_child_fixed_kernel( memcpy(&value, &raw, sizeof(value)); } - out[occ_idx] = value; - valid[occ_idx] = true; + out[idx] = value; + valid[idx] = true; } /** @@ -1501,9 +1486,7 @@ __global__ void extract_repeated_msg_child_strings_kernel( char* str_dst = output_chars + string_offsets[idx]; // Copy string data - for (int i = 0; i < field_loc.length; i++) { - str_dst[i] = static_cast(str_src[i]); - } + memcpy(str_dst, str_src, field_loc.length); } /** @@ -1543,7 +1526,7 @@ inline std::unique_ptr build_repeated_msg_child_string_column( return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); } - auto const threads = 256; + auto const threads = THREADS_PER_BLOCK; auto const blocks = (total_count + threads - 1) / threads; // Compute string lengths on GPU @@ -1625,7 +1608,7 @@ inline std::unique_ptr build_repeated_msg_child_bytes_column( 0, rmm::device_buffer{}, stream, mr); } - auto const threads = 256; + auto const threads = THREADS_PER_BLOCK; auto const blocks = (total_count + threads - 1) / threads; rmm::device_uvector d_lengths(total_count, stream, mr); @@ -1796,7 +1779,7 @@ struct extract_strided_count { /** * Extract varint from nested message locations. 
*/ -template +template __global__ void extract_nested_varint_kernel( uint8_t const* message_data, cudf::size_type const* parent_row_offsets, @@ -1805,7 +1788,7 @@ __global__ void extract_nested_varint_kernel( field_location const* field_locations, int field_idx, int num_fields, - OutT* out, + OutputType* out, bool* valid, int num_rows, int* error_flag, @@ -1820,7 +1803,7 @@ __global__ void extract_nested_varint_kernel( if (parent_loc.offset < 0 || field_loc.offset < 0) { if (has_default) { - out[row] = static_cast(default_value); + out[row] = static_cast(default_value); valid[row] = true; } else { valid[row] = false; @@ -1835,20 +1818,20 @@ __global__ void extract_nested_varint_kernel( uint64_t v; int n; if (!read_varint(cur, cur_end, v, n)) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_VARINT); valid[row] = false; return; } if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } - out[row] = static_cast(v); + out[row] = static_cast(v); valid[row] = true; } /** * Extract fixed-size from nested message locations. 
*/ -template +template __global__ void extract_nested_fixed_kernel( uint8_t const* message_data, cudf::size_type const* parent_row_offsets, @@ -1857,12 +1840,12 @@ __global__ void extract_nested_fixed_kernel( field_location const* field_locations, int field_idx, int num_fields, - OutT* out, + OutputType* out, bool* valid, int num_rows, int* error_flag, bool has_default = false, - OutT default_value = OutT{}) + OutputType default_value = OutputType{}) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; @@ -1883,10 +1866,10 @@ __global__ void extract_nested_fixed_kernel( auto parent_row_start = parent_row_offsets[row] - parent_base_offset; uint8_t const* cur = message_data + parent_row_start + parent_loc.offset + field_loc.offset; - OutT value; + OutputType value; if constexpr (WT == WT_32BIT) { if (field_loc.length < 4) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_FIXED_LEN); valid[row] = false; return; } @@ -1894,7 +1877,7 @@ __global__ void extract_nested_fixed_kernel( memcpy(&value, &raw, sizeof(value)); } else { if (field_loc.length < 8) { - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_FIXED_LEN); valid[row] = false; return; } @@ -1934,9 +1917,7 @@ __global__ void copy_nested_varlen_data_kernel( if (parent_loc.offset < 0 || field_loc.offset < 0) { if (has_default && default_length > 0) { - for (int i = 0; i < default_length; i++) { - dst[i] = static_cast(default_data[i]); - } + memcpy(dst, default_data, default_length); } return; } @@ -1946,9 +1927,7 @@ __global__ void copy_nested_varlen_data_kernel( auto parent_row_start = parent_row_offsets[row] - parent_base_offset; uint8_t const* src = message_data + parent_row_start + parent_loc.offset + field_loc.offset; - for (int i = 0; i < field_loc.length; i++) { - dst[i] = static_cast(src[i]); - } + memcpy(dst, src, field_loc.length); } /** @@ -2034,9 +2013,7 @@ __global__ void copy_scalar_string_data_kernel( if (loc.offset < 0) { // Field not found - 
use default if available if (has_default && default_length > 0) { - for (int i = 0; i < default_length; i++) { - dst[i] = static_cast(default_data[i]); - } + memcpy(dst, default_data, default_length); } return; } @@ -2046,9 +2023,7 @@ __global__ void copy_scalar_string_data_kernel( auto row_start = row_offsets[row] - row_base_offset; uint8_t const* src = message_data + row_start + loc.offset; - for (int i = 0; i < loc.length; i++) { - dst[i] = static_cast(src[i]); - } + memcpy(dst, src, loc.length); } // ============================================================================ @@ -2121,10 +2096,10 @@ std::unique_ptr make_null_column(cudf::data_type dtype, mr); } case cudf::type_id::STRUCT: { - // Create STRUCT with all nulls and no children - // Note: This is a workaround. Proper nested struct handling requires recursive processing - // with full schema information. An empty struct with no children won't match expected - // schema for deeply nested types, but prevents crashes for unprocessed struct fields. + // TODO(protobuf): This creates an empty STRUCT with no children, which does not + // match the expected nested schema. This is a crash-prevention workaround for + // unprocessed struct fields at deep nesting levels. A proper fix would recurse + // into the schema to build the correct child column structure with all-null leaves. 
std::vector> empty_children; auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); return cudf::make_structs_column( @@ -2255,7 +2230,7 @@ __global__ void check_required_fields_kernel( for (int f = 0; f < num_fields; f++) { if (is_required[f] != 0 && locations[row * num_fields + f].offset < 0) { // Required field is missing - set error flag - atomicExch(error_flag, 1); + atomicExch(error_flag, ERR_REQUIRED); return; // No need to check other fields for this row } } @@ -2271,6 +2246,8 @@ __global__ void check_required_fields_kernel( * encountered, the entire struct row is set to null (not just the enum field). * * The valid_values array must be sorted for binary search. + * + * @note Time complexity: O(log(num_valid_values)) per row. */ __global__ void validate_enum_values_kernel( int32_t const* values, // [num_rows] extracted enum values @@ -2476,7 +2453,7 @@ std::unique_ptr build_repeated_scalar_column( rmm::device_uvector d_error(1, stream, mr); CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); - auto const threads = 256; + auto const threads = THREADS_PER_BLOCK; auto const blocks = (total_count + threads - 1) / threads; int encoding = field_desc.encoding; @@ -2592,7 +2569,7 @@ std::unique_ptr build_repeated_string_column( // Extract string lengths from occurrences rmm::device_uvector str_lengths(total_count, stream, mr); - auto const threads = 256; + auto const threads = THREADS_PER_BLOCK; auto const blocks = (total_count + threads - 1) / threads; extract_repeated_lengths_kernel<<>>( d_occurrences.data(), total_count, str_lengths.data()); @@ -2787,7 +2764,7 @@ std::unique_ptr build_repeated_struct_column( rmm::device_uvector d_msg_locs(total_count, stream, mr); rmm::device_uvector d_msg_row_offsets(total_count, stream, mr); { - auto const occ_threads = 256; + auto const occ_threads = THREADS_PER_BLOCK; auto const occ_blocks = (total_count + occ_threads - 1) / occ_threads; 
compute_msg_locations_from_occurrences_kernel<<>>( d_occurrences.data(), list_offsets, base_offset, @@ -2799,7 +2776,7 @@ std::unique_ptr build_repeated_struct_column( rmm::device_uvector d_error(1, stream, mr); CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); - auto const threads = 256; + auto const threads = THREADS_PER_BLOCK; auto const blocks = (total_count + threads - 1) / threads; // Use a custom kernel to scan child fields within message occurrences @@ -3030,7 +3007,7 @@ std::unique_ptr build_nested_struct_column( return cudf::make_structs_column(0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); } - auto const threads = 256; + auto const threads = THREADS_PER_BLOCK; auto const blocks = (num_rows + threads - 1) / threads; int num_child_fields = static_cast(child_field_indices.size()); @@ -3113,12 +3090,6 @@ std::unique_ptr build_nested_struct_column( struct_children.push_back(cudf::make_lists_column( num_rows, std::move(list_offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr)); } else { - rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); - scan_repeated_in_nested_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, - d_rep_schema.data(), 1, d_rep_info.data(), 1, d_rep_indices.data(), - d_rep_occs.data(), d_error.data()); - rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan(rmm::exec_policy(stream), d_rep_counts.begin(), d_rep_counts.end(), @@ -3126,38 +3097,49 @@ std::unique_ptr build_nested_struct_column( CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_rep_count, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); + scan_repeated_in_nested_kernel<<>>( + message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, + d_rep_schema.data(), 1, list_offs.data(), 1, d_rep_indices.data(), + d_rep_occs.data(), d_error.data()); 
+ std::unique_ptr child_values; if (elem_type_id == cudf::type_id::INT32) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_varint_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + extract_repeated_in_nested_varint_kernel<<< + (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( message_data, list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); child_values = std::make_unique( cudf::data_type{cudf::type_id::INT32}, total_rep_count, values.release(), rmm::device_buffer{}, 0); } else if (elem_type_id == cudf::type_id::INT64) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_varint_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + extract_repeated_in_nested_varint_kernel<<< + (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( message_data, list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); child_values = std::make_unique( cudf::data_type{cudf::type_id::INT64}, total_rep_count, values.release(), rmm::device_buffer{}, 0); } else if (elem_type_id == cudf::type_id::BOOL8) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_varint_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + extract_repeated_in_nested_varint_kernel<<< + (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( message_data, list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); child_values = std::make_unique( cudf::data_type{cudf::type_id::BOOL8}, total_rep_count, values.release(), rmm::device_buffer{}, 0); } else if (elem_type_id == cudf::type_id::FLOAT32) { rmm::device_uvector 
values(total_rep_count, stream, mr); - extract_repeated_in_nested_fixed_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + extract_repeated_in_nested_fixed_kernel<<< + (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( message_data, list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); child_values = std::make_unique( cudf::data_type{cudf::type_id::FLOAT32}, total_rep_count, values.release(), rmm::device_buffer{}, 0); } else if (elem_type_id == cudf::type_id::FLOAT64) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_fixed_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + extract_repeated_in_nested_fixed_kernel<<< + (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( message_data, list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); child_values = std::make_unique( @@ -3180,7 +3162,8 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { - extract_repeated_in_nested_string_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + extract_repeated_in_nested_string_kernel<<< + (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( message_data, list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data(), total_rep_count, str_offs.data(), chars.data(), d_error.data()); } @@ -3206,7 +3189,8 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector bytes(total_bytes, stream, mr); if (total_bytes > 0) { - extract_repeated_in_nested_string_kernel<<<(total_rep_count + 255) / 256, 256, 0, stream.value()>>>( + extract_repeated_in_nested_string_kernel<<< + (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, 
THREADS_PER_BLOCK, 0, stream.value()>>>( message_data, list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data(), total_rep_count, byte_offs.data(), bytes.data(), d_error.data()); } @@ -3229,8 +3213,8 @@ std::unique_ptr build_nested_struct_column( } else { rmm::device_uvector d_virtual_row_offsets(total_rep_count, stream, mr); rmm::device_uvector d_virtual_parent_locs(total_rep_count, stream, mr); - auto const rep_blk = (total_rep_count + 255) / 256; - compute_virtual_parents_for_nested_repeated_kernel<<>>( + auto const rep_blk = (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + compute_virtual_parents_for_nested_repeated_kernel<<>>( d_rep_occs.data(), list_offsets, d_parent_locs.data(), d_virtual_row_offsets.data(), d_virtual_parent_locs.data(), total_rep_count); @@ -3741,7 +3725,7 @@ std::unique_ptr decode_protobuf_to_struct( num_rows * sizeof(bool), stream.value())); } - auto const threads = 256; + auto const threads = THREADS_PER_BLOCK; auto const blocks = static_cast((num_rows + threads - 1) / threads); // Allocate for counting repeated fields @@ -4294,8 +4278,8 @@ std::unique_ptr decode_protobuf_to_struct( // 1. Extract enum integer values from occurrences rmm::device_uvector enum_ints(total_count, stream, mr); - auto const rep_blocks = static_cast((total_count + 255) / 256); - extract_repeated_varint_kernel<<>>( + auto const rep_blocks = static_cast((total_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK); + extract_repeated_varint_kernel<<>>( msg_data, loffs, boff, d_occurrences.data(), total_count, enum_ints.data(), d_error.data()); @@ -4336,7 +4320,7 @@ std::unique_ptr decode_protobuf_to_struct( // 4. 
Compute per-element string lengths rmm::device_uvector elem_lengths(total_count, stream, mr); - compute_enum_string_lengths_kernel<<>>( + compute_enum_string_lengths_kernel<<>>( enum_ints.data(), elem_valid.data(), d_valid_enums.data(), d_name_offsets.data(), static_cast(valid_enums.size()), elem_lengths.data(), total_count); @@ -4361,7 +4345,7 @@ std::unique_ptr decode_protobuf_to_struct( // 6. Copy string chars rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { - copy_enum_string_chars_kernel<<>>( + copy_enum_string_chars_kernel<<>>( enum_ints.data(), elem_valid.data(), d_valid_enums.data(), d_name_offsets.data(), d_name_chars.data(), static_cast(valid_enums.size()), str_offsets.data(), chars.data(), total_count); From fd7ec66245f3fec3b386b17c7a01e5d833423972 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 25 Feb 2026 11:32:19 +0800 Subject: [PATCH 024/107] style change Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 94 +- src/main/cpp/src/protobuf.cu | 3604 ++++++++++++++++++++---------- src/main/cpp/src/protobuf.hpp | 22 +- 3 files changed, 2438 insertions(+), 1282 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index 8fb6ac81a6..d76d58f59e 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -25,24 +25,24 @@ extern "C" { JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, - jclass, - jlong binary_input_view, - jintArray field_numbers, - jintArray parent_indices, - jintArray depth_levels, - jintArray wire_types, - jintArray output_type_ids, - jintArray encodings, - jbooleanArray is_repeated, - jbooleanArray is_required, - jbooleanArray has_default_value, - jlongArray default_ints, - jdoubleArray default_floats, - jbooleanArray default_bools, - jobjectArray default_strings, - jobjectArray enum_valid_values, - jobjectArray enum_names, - jboolean fail_on_errors) + jclass, + jlong binary_input_view, 
+ jintArray field_numbers, + jintArray parent_indices, + jintArray depth_levels, + jintArray wire_types, + jintArray output_type_ids, + jintArray encodings, + jbooleanArray is_repeated, + jbooleanArray is_required, + jbooleanArray has_default_value, + jlongArray default_ints, + jdoubleArray default_floats, + jbooleanArray default_bools, + jobjectArray default_strings, + jobjectArray enum_valid_values, + jobjectArray enum_names, + jboolean fail_on_errors) { JNI_NULL_CHECK(env, binary_input_view, "binary_input_view is null", 0); JNI_NULL_CHECK(env, field_numbers, "field_numbers is null", 0); @@ -82,16 +82,11 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, int num_fields = n_field_numbers.size(); // Validate array sizes - if (n_parent_indices.size() != num_fields || - n_depth_levels.size() != num_fields || - n_wire_types.size() != num_fields || - n_output_type_ids.size() != num_fields || - n_encodings.size() != num_fields || - n_is_repeated.size() != num_fields || - n_is_required.size() != num_fields || - n_has_default.size() != num_fields || - n_default_ints.size() != num_fields || - n_default_floats.size() != num_fields || + if (n_parent_indices.size() != num_fields || n_depth_levels.size() != num_fields || + n_wire_types.size() != num_fields || n_output_type_ids.size() != num_fields || + n_encodings.size() != num_fields || n_is_repeated.size() != num_fields || + n_is_required.size() != num_fields || n_has_default.size() != num_fields || + n_default_ints.size() != num_fields || n_default_floats.size() != num_fields || n_default_bools.size() != num_fields) { JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, @@ -103,17 +98,15 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, std::vector schema; schema.reserve(num_fields); for (int i = 0; i < num_fields; ++i) { - schema.push_back({ - n_field_numbers[i], - n_parent_indices[i], - n_depth_levels[i], - n_wire_types[i], - static_cast(n_output_type_ids[i]), - 
n_encodings[i], - n_is_repeated[i] != 0, - n_is_required[i] != 0, - n_has_default[i] != 0 - }); + schema.push_back({n_field_numbers[i], + n_parent_indices[i], + n_depth_levels[i], + n_wire_types[i], + static_cast(n_output_type_ids[i]), + n_encodings[i], + n_is_repeated[i] != 0, + n_is_required[i] != 0, + n_has_default[i] != 0}); } // Build output types @@ -183,7 +176,7 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, if (name_bytes == nullptr) { names_for_field.emplace_back(); } else { - jsize len = env->GetArrayLength(name_bytes); + jsize len = env->GetArrayLength(name_bytes); jbyte* bytes = env->GetByteArrayElements(name_bytes, nullptr); names_for_field.emplace_back(reinterpret_cast(bytes), reinterpret_cast(bytes) + len); @@ -194,17 +187,16 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, } } - auto result = spark_rapids_jni::decode_protobuf_to_struct( - *input, - schema, - schema_output_types, - default_int_values, - default_float_values, - default_bool_values, - default_string_values, - enum_values, - enum_name_values, - fail_on_errors); + auto result = spark_rapids_jni::decode_protobuf_to_struct(*input, + schema, + schema_output_types, + default_int_values, + default_float_values, + default_bool_values, + default_string_values, + enum_values, + enum_name_values, + fail_on_errors); return cudf::jni::release_as_jlong(result); } diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index df6c1c6b36..a678a9b332 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -92,17 +92,17 @@ struct field_descriptor { * Information about repeated field occurrences in a row. 
*/ struct repeated_field_info { - int32_t count; // Number of occurrences in this row - int32_t total_length; // Total bytes for all occurrences (for varlen fields) + int32_t count; // Number of occurrences in this row + int32_t total_length; // Total bytes for all occurrences (for varlen fields) }; /** * Location of a single occurrence of a repeated field. */ struct repeated_occurrence { - int32_t row_idx; // Which row this occurrence belongs to - int32_t offset; // Offset within the message - int32_t length; // Length of the field data + int32_t row_idx; // Which row this occurrence belongs to + int32_t offset; // Offset within the message + int32_t length; // Length of the field data }; /** @@ -197,11 +197,8 @@ __device__ inline bool skip_field(uint8_t const* cur, * Get the data offset and length for a field at current position. * Returns true on success, false on error. */ -__device__ inline bool get_field_data_location(uint8_t const* cur, - uint8_t const* end, - int wt, - int32_t& data_offset, - int32_t& data_length) +__device__ inline bool get_field_data_location( + uint8_t const* cur, uint8_t const* end, int wt, int32_t& data_offset, int32_t& data_length) { if (wt == WT_LEN) { // For length-delimited, read the length prefix @@ -324,8 +321,8 @@ __global__ void scan_all_fields_kernel( auto const base = in.offset_at(0); auto const child = in.get_sliced_child(); auto const* bytes = reinterpret_cast(child.data()); - int32_t start = in.offset_at(row) - base; - int32_t end = in.offset_at(row + 1) - base; + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } @@ -403,13 +400,14 @@ __global__ void count_repeated_fields_kernel( cudf::column_device_view const d_in, device_nested_field_descriptor const* schema, int num_fields, - int depth_level, // Which depth level we're processing - repeated_field_info* repeated_info, // [num_rows * 
num_repeated_fields_at_this_depth] - int num_repeated_fields, // Number of repeated fields at this depth - int const* repeated_field_indices, // Indices into schema for repeated fields at this depth - field_location* nested_locations, // Locations of nested messages for next depth [num_rows * num_nested] - int num_nested_fields, // Number of nested message fields at this depth - int const* nested_field_indices, // Indices into schema for nested message fields + int depth_level, // Which depth level we're processing + repeated_field_info* repeated_info, // [num_rows * num_repeated_fields_at_this_depth] + int num_repeated_fields, // Number of repeated fields at this depth + int const* repeated_field_indices, // Indices into schema for repeated fields at this depth + field_location* + nested_locations, // Locations of nested messages for next depth [num_rows * num_nested] + int num_nested_fields, // Number of nested message fields at this depth + int const* nested_field_indices, // Indices into schema for nested message fields int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); @@ -426,15 +424,13 @@ __global__ void count_repeated_fields_kernel( nested_locations[row * num_nested_fields + f] = {-1, 0}; } - if (in.nullable() && in.is_null(row)) { - return; - } + if (in.nullable() && in.is_null(row)) { return; } auto const base = in.offset_at(0); auto const child = in.get_sliced_child(); auto const* bytes = reinterpret_cast(child.data()); - int32_t start = in.offset_at(row) - base; - int32_t end = in.offset_at(row + 1) - base; + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } uint8_t const* cur = bytes + start; @@ -451,16 +447,16 @@ __global__ void count_repeated_fields_kernel( int schema_idx = repeated_field_indices[i]; if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { int expected_wt = 
schema[schema_idx].wire_type; - + // Handle both packed and unpacked encoding for repeated fields // Packed encoding uses wire type LEN (2) even for scalar types bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); - + if (!is_packed && wt != expected_wt) { atomicExch(error_flag, ERR_WIRE_TYPE); return; } - + if (is_packed) { // Packed encoding: read length, then count elements inside uint64_t packed_len; @@ -469,15 +465,15 @@ __global__ void count_repeated_fields_kernel( atomicExch(error_flag, ERR_VARINT); return; } - + // Count elements based on type uint8_t const* packed_start = cur + len_bytes; - uint8_t const* packed_end = packed_start + packed_len; + uint8_t const* packed_end = packed_start + packed_len; if (packed_end > msg_end) { atomicExch(error_flag, ERR_OVERFLOW); return; } - + int count = 0; if (expected_wt == WT_VARINT) { // Count varints in the packed data @@ -497,9 +493,10 @@ __global__ void count_repeated_fields_kernel( } else if (expected_wt == WT_64BIT) { count = static_cast(packed_len) / 8; } - + repeated_info[row * num_repeated_fields + i].count += count; - repeated_info[row * num_repeated_fields + i].total_length += static_cast(packed_len); + repeated_info[row * num_repeated_fields + i].total_length += + static_cast(packed_len); } else { // Non-packed encoding: single element int32_t data_offset, data_length; @@ -507,7 +504,7 @@ __global__ void count_repeated_fields_kernel( atomicExch(error_flag, ERR_FIELD_SIZE); return; } - + repeated_info[row * num_repeated_fields + i].count++; repeated_info[row * num_repeated_fields + i].total_length += data_length; } @@ -522,14 +519,14 @@ __global__ void count_repeated_fields_kernel( atomicExch(error_flag, ERR_WIRE_TYPE); return; } - + uint64_t len; int len_bytes; if (!read_varint(cur, msg_end, len, len_bytes)) { atomicExch(error_flag, ERR_VARINT); return; } - + int32_t msg_offset = static_cast(cur - bytes - start) + len_bytes; nested_locations[row * num_nested_fields + i] = {msg_offset, 
static_cast(len)}; } @@ -554,25 +551,23 @@ __global__ void count_repeated_fields_kernel( __global__ void scan_repeated_field_occurrences_kernel( cudf::column_device_view const d_in, device_nested_field_descriptor const* schema, - int schema_idx, // Which field in schema we're scanning + int schema_idx, // Which field in schema we're scanning int depth_level, - int32_t const* output_offsets, // Pre-computed offsets from prefix sum [num_rows + 1] - repeated_occurrence* occurrences, // Output: all occurrences [total_count] + int32_t const* output_offsets, // Pre-computed offsets from prefix sum [num_rows + 1] + repeated_occurrence* occurrences, // Output: all occurrences [total_count] int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; if (row >= in.size()) return; - if (in.nullable() && in.is_null(row)) { - return; - } + if (in.nullable() && in.is_null(row)) { return; } auto const base = in.offset_at(0); auto const child = in.get_sliced_child(); auto const* bytes = reinterpret_cast(child.data()); - int32_t start = in.offset_at(row) - base; - int32_t end = in.offset_at(row + 1) - base; + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } uint8_t const* cur = bytes + start; @@ -591,7 +586,7 @@ __global__ void scan_repeated_field_occurrences_kernel( if (fn == target_fn) { // Check for packed encoding: wire type LEN but expected non-LEN bool is_packed = (wt == WT_LEN && target_wt != WT_LEN); - + if (is_packed) { // Packed encoding: multiple elements in a length-delimited blob uint64_t packed_len; @@ -600,14 +595,14 @@ __global__ void scan_repeated_field_occurrences_kernel( atomicExch(error_flag, ERR_VARINT); return; } - + uint8_t const* packed_start = cur + len_bytes; - uint8_t const* packed_end = packed_start + packed_len; + uint8_t const* packed_end = packed_start + packed_len; 
if (packed_end > msg_end) { atomicExch(error_flag, ERR_OVERFLOW); return; } - + // Record each element in the packed blob if (target_wt == WT_VARINT) { // Varints: parse each one @@ -628,7 +623,7 @@ __global__ void scan_repeated_field_occurrences_kernel( // Fixed 32-bit: each element is 4 bytes uint8_t const* p = packed_start; while (p + 4 <= packed_end) { - int32_t elem_offset = static_cast(p - bytes - start); + int32_t elem_offset = static_cast(p - bytes - start); occurrences[write_idx] = {static_cast(row), elem_offset, 4}; write_idx++; p += 4; @@ -637,7 +632,7 @@ __global__ void scan_repeated_field_occurrences_kernel( // Fixed 64-bit: each element is 8 bytes uint8_t const* p = packed_start; while (p + 8 <= packed_end) { - int32_t elem_offset = static_cast(p - bytes - start); + int32_t elem_offset = static_cast(p - bytes - start); occurrences[write_idx] = {static_cast(row), elem_offset, 8}; write_idx++; p += 8; @@ -650,8 +645,8 @@ __global__ void scan_repeated_field_occurrences_kernel( atomicExch(error_flag, ERR_FIELD_SIZE); return; } - - int32_t abs_offset = static_cast(cur - bytes - start) + data_offset; + + int32_t abs_offset = static_cast(cur - bytes - start) + data_offset; occurrences[write_idx] = {static_cast(row), abs_offset, data_length}; write_idx++; } @@ -738,7 +733,7 @@ __global__ void extract_fixed_from_locations_kernel(uint8_t const* message_data, bool* valid, int num_rows, int* error_flag, - bool has_default = false, + bool has_default = false, OutputType default_value = OutputType{}) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); @@ -790,21 +785,20 @@ __global__ void extract_fixed_from_locations_kernel(uint8_t const* message_data, * Extract repeated varint values using pre-recorded occurrences. 
*/ template -__global__ void extract_repeated_varint_kernel( - uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - repeated_occurrence const* occurrences, - int total_occurrences, - OutputType* out, - int* error_flag) +__global__ void extract_repeated_varint_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + repeated_occurrence const* occurrences, + int total_occurrences, + OutputType* out, + int* error_flag) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx >= total_occurrences) return; - auto const& occ = occurrences[idx]; - auto row_start = row_offsets[occ.row_idx] - base_offset; - uint8_t const* cur = message_data + row_start + occ.offset; + auto const& occ = occurrences[idx]; + auto row_start = row_offsets[occ.row_idx] - base_offset; + uint8_t const* cur = message_data + row_start + occ.offset; uint8_t const* cur_end = cur + occ.length; uint64_t v; @@ -823,20 +817,19 @@ __global__ void extract_repeated_varint_kernel( * Extract repeated fixed-size values using pre-recorded occurrences. 
*/ template -__global__ void extract_repeated_fixed_kernel( - uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - repeated_occurrence const* occurrences, - int total_occurrences, - OutputType* out, - int* error_flag) +__global__ void extract_repeated_fixed_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + repeated_occurrence const* occurrences, + int total_occurrences, + OutputType* out, + int* error_flag) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx >= total_occurrences) return; - auto const& occ = occurrences[idx]; - auto row_start = row_offsets[occ.row_idx] - base_offset; + auto const& occ = occurrences[idx]; + auto row_start = row_offsets[occ.row_idx] - base_offset; uint8_t const* cur = message_data + row_start + occ.offset; OutputType value; @@ -879,9 +872,9 @@ __global__ void copy_repeated_varlen_data_kernel( auto const& occ = occurrences[idx]; if (occ.length == 0) return; - auto row_start = row_offsets[occ.row_idx] - base_offset; + auto row_start = row_offsets[occ.row_idx] - base_offset; uint8_t const* src = message_data + row_start + occ.offset; - char* dst = output_data + output_offsets[idx]; + char* dst = output_data + output_offsets[idx]; memcpy(dst, src, occ.length); } @@ -889,10 +882,9 @@ __global__ void copy_repeated_varlen_data_kernel( /** * Extract lengths from repeated occurrences for prefix sum. */ -__global__ void extract_repeated_lengths_kernel( - repeated_occurrence const* occurrences, - int total_occurrences, - int32_t* lengths) +__global__ void extract_repeated_lengths_kernel(repeated_occurrence const* occurrences, + int total_occurrences, + int32_t* lengths) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx >= total_occurrences) return; @@ -909,16 +901,15 @@ __global__ void extract_repeated_lengths_kernel( * Each row represents a nested message at a specific parent location. 
* This kernel finds fields within the nested message bytes. */ -__global__ void scan_nested_message_fields_kernel( - uint8_t const* message_data, - cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, - field_location const* parent_locations, - int num_parent_rows, - field_descriptor const* field_descs, - int num_fields, - field_location* output_locations, - int* error_flag) +__global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + int num_parent_rows, + field_descriptor const* field_descs, + int num_fields, + field_location* output_locations, + int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_parent_rows) return; @@ -928,13 +919,11 @@ __global__ void scan_nested_message_fields_kernel( } auto const& parent_loc = parent_locations[row]; - if (parent_loc.offset < 0) { - return; - } + if (parent_loc.offset < 0) { return; } - auto parent_row_start = parent_row_offsets[row] - parent_base_offset; + auto parent_row_start = parent_row_offsets[row] - parent_base_offset; uint8_t const* nested_start = message_data + parent_row_start + parent_loc.offset; - uint8_t const* nested_end = nested_start + parent_loc.length; + uint8_t const* nested_end = nested_start + parent_loc.length; uint8_t const* cur = nested_start; @@ -965,7 +954,8 @@ __global__ void scan_nested_message_fields_kernel( atomicExch(error_flag, ERR_OVERFLOW); return; } - output_locations[row * num_fields + f] = {data_offset + len_bytes, static_cast(len)}; + output_locations[row * num_fields + f] = {data_offset + len_bytes, + static_cast(len)}; } else { int field_size = get_wire_type_size(wt, cur, nested_end); if (field_size < 0) { @@ -999,7 +989,9 @@ inline std::pair make_null_mask_from_valid( { auto begin = thrust::make_counting_iterator(0); auto end = begin + valid.size(); - auto pred = [ptr = 
valid.data()] __device__(cudf::size_type i) { return static_cast(ptr[i]); }; + auto pred = [ptr = valid.data()] __device__(cudf::size_type i) { + return static_cast(ptr[i]); + }; return cudf::detail::valid_if(begin, end, pred, stream, mr); } @@ -1009,12 +1001,13 @@ inline std::pair make_null_mask_from_valid( */ __global__ void scan_repeated_message_children_kernel( uint8_t const* message_data, - int32_t const* msg_row_offsets, // Row offset for each occurrence - field_location const* msg_locs, // Location of each message occurrence (offset within row, length) + int32_t const* msg_row_offsets, // Row offset for each occurrence + field_location const* + msg_locs, // Location of each message occurrence (offset within row, length) int num_occurrences, field_descriptor const* child_descs, int num_child_fields, - field_location* child_locs, // Output: [num_occurrences * num_child_fields] + field_location* child_locs, // Output: [num_occurrences * num_child_fields] int* error_flag) { auto occ_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); @@ -1029,9 +1022,9 @@ __global__ void scan_repeated_message_children_kernel( if (msg_loc.offset < 0) return; // Calculate absolute position of this message in the data - int32_t row_offset = msg_row_offsets[occ_idx]; + int32_t row_offset = msg_row_offsets[occ_idx]; uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; - uint8_t const* msg_end = msg_start + msg_loc.length; + uint8_t const* msg_end = msg_start + msg_loc.length; uint8_t const* cur = msg_start; @@ -1060,16 +1053,15 @@ __global__ void scan_repeated_message_children_kernel( return; } // Store offset (after length prefix) and length - child_locs[occ_idx * num_child_fields + f] = {data_offset + len_bytes, static_cast(len)}; + child_locs[occ_idx * num_child_fields + f] = {data_offset + len_bytes, + static_cast(len)}; } else { // For varint/fixed types, store offset and estimated length int32_t data_length = 0; if (wt == WT_VARINT) { uint64_t dummy; int 
vbytes; - if (read_varint(cur, msg_end, dummy, vbytes)) { - data_length = vbytes; - } + if (read_varint(cur, msg_end, dummy, vbytes)) { data_length = vbytes; } } else if (wt == WT_32BIT) { data_length = 4; } else if (wt == WT_64BIT) { @@ -1095,18 +1087,17 @@ __global__ void scan_repeated_message_children_kernel( * Count repeated field occurrences within nested messages. * Similar to count_repeated_fields_kernel but operates on nested message locations. */ -__global__ void count_repeated_in_nested_kernel( - uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_rows, - device_nested_field_descriptor const* schema, - int num_fields, - repeated_field_info* repeated_info, - int num_repeated, - int const* repeated_indices, - int* error_flag) +__global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + repeated_field_info* repeated_info, + int num_repeated, + int const* repeated_indices, + int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; @@ -1121,10 +1112,10 @@ __global__ void count_repeated_in_nested_kernel( cudf::size_type row_off; row_off = row_offsets[row] - base_offset; - + uint8_t const* msg_start = message_data + row_off + parent_loc.offset; - uint8_t const* msg_end = msg_start + parent_loc.length; - uint8_t const* cur = msg_start; + uint8_t const* msg_end = msg_start + parent_loc.length; + uint8_t const* cur = msg_start; while (cur < msg_end) { proto_tag tag; @@ -1163,19 +1154,18 @@ __global__ void count_repeated_in_nested_kernel( /** * Scan for repeated field occurrences within nested messages. 
*/ -__global__ void scan_repeated_in_nested_kernel( - uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_rows, - device_nested_field_descriptor const* schema, - int num_fields, - int32_t const* occ_prefix_sums, - int num_repeated, - int const* repeated_indices, - repeated_occurrence* occurrences, - int* error_flag) +__global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + int32_t const* occ_prefix_sums, + int num_repeated, + int const* repeated_indices, + repeated_occurrence* occurrences, + int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; @@ -1187,10 +1177,10 @@ __global__ void scan_repeated_in_nested_kernel( int occ_offset = occ_prefix_sums[row]; cudf::size_type row_off = row_offsets[row] - base_offset; - + uint8_t const* msg_start = message_data + row_off + parent_loc.offset; - uint8_t const* msg_end = msg_start + parent_loc.length; - uint8_t const* cur = msg_start; + uint8_t const* msg_end = msg_start + parent_loc.length; + uint8_t const* cur = msg_start; int occ_idx = 0; @@ -1205,7 +1195,7 @@ __global__ void scan_repeated_in_nested_kernel( int schema_idx = repeated_indices[ri]; if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { int32_t data_offset = static_cast(cur - msg_start); - int32_t data_len = 0; + int32_t data_len = 0; if (wt == WT_LEN) { uint64_t len; @@ -1219,9 +1209,7 @@ __global__ void scan_repeated_in_nested_kernel( } else if (wt == WT_VARINT) { uint64_t dummy; int vbytes; - if (read_varint(cur, msg_end, dummy, vbytes)) { - data_len = vbytes; - } + if (read_varint(cur, msg_end, dummy, vbytes)) { data_len = vbytes; } } else if (wt == WT_32BIT) { data_len = 4; } 
else if (wt == WT_64BIT) { @@ -1246,22 +1234,21 @@ __global__ void scan_repeated_in_nested_kernel( * Extract varint values from repeated field occurrences within nested messages. */ template -__global__ void extract_repeated_in_nested_varint_kernel( - uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - repeated_occurrence const* occurrences, - int total_count, - OutputType* out, - int* error_flag) +__global__ void extract_repeated_in_nested_varint_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + repeated_occurrence const* occurrences, + int total_count, + OutputType* out, + int* error_flag) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx >= total_count) return; - auto const& occ = occurrences[idx]; + auto const& occ = occurrences[idx]; auto const& parent_loc = parent_locs[occ.row_idx]; - + cudf::size_type row_off = row_offsets[occ.row_idx] - base_offset; uint8_t const* data_ptr = message_data + row_off + parent_loc.offset + occ.offset; uint8_t const* msg_end = message_data + row_off + parent_loc.offset + parent_loc.length; @@ -1281,20 +1268,19 @@ __global__ void extract_repeated_in_nested_varint_kernel( } template -__global__ void extract_repeated_in_nested_fixed_kernel( - uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - repeated_occurrence const* occurrences, - int total_count, - OutputType* out, - int* error_flag) +__global__ void extract_repeated_in_nested_fixed_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + repeated_occurrence const* occurrences, + int total_count, + OutputType* out, + int* error_flag) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx >= 
total_count) return; - auto const& occ = occurrences[idx]; + auto const& occ = occurrences[idx]; auto const& parent_loc = parent_locs[occ.row_idx]; cudf::size_type row_off = row_offsets[occ.row_idx] - base_offset; @@ -1322,26 +1308,25 @@ __global__ void extract_repeated_in_nested_fixed_kernel( /** * Extract string values from repeated field occurrences within nested messages. */ -__global__ void extract_repeated_in_nested_string_kernel( - uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - repeated_occurrence const* occurrences, - int total_count, - int32_t const* str_offsets, - char* chars, - int* error_flag) +__global__ void extract_repeated_in_nested_string_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + repeated_occurrence const* occurrences, + int total_count, + int32_t const* str_offsets, + char* chars, + int* error_flag) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx >= total_count) return; - auto const& occ = occurrences[idx]; + auto const& occ = occurrences[idx]; auto const& parent_loc = parent_locs[occ.row_idx]; - + cudf::size_type row_off = row_offsets[occ.row_idx] - base_offset; uint8_t const* data_ptr = message_data + row_off + parent_loc.offset + occ.offset; - + int32_t out_offset = str_offsets[idx]; memcpy(chars + out_offset, data_ptr, occ.length); } @@ -1350,29 +1335,28 @@ __global__ void extract_repeated_in_nested_string_kernel( * Extract varint child fields from repeated message occurrences. 
*/ template -__global__ void extract_repeated_msg_child_varint_kernel( - uint8_t const* message_data, - int32_t const* msg_row_offsets, - field_location const* msg_locs, - field_location const* child_locs, - int child_idx, - int num_child_fields, - OutputType* out, - bool* valid, - int num_occurrences, - int* error_flag, - bool has_default = false, - int64_t default_value = 0) +__global__ void extract_repeated_msg_child_varint_kernel(uint8_t const* message_data, + int32_t const* msg_row_offsets, + field_location const* msg_locs, + field_location const* child_locs, + int child_idx, + int num_child_fields, + OutputType* out, + bool* valid, + int num_occurrences, + int* error_flag, + bool has_default = false, + int64_t default_value = 0) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx >= num_occurrences) return; - auto const& msg_loc = msg_locs[idx]; + auto const& msg_loc = msg_locs[idx]; auto const& field_loc = child_locs[idx * num_child_fields + child_idx]; if (msg_loc.offset < 0 || field_loc.offset < 0) { if (has_default) { - out[idx] = static_cast(default_value); + out[idx] = static_cast(default_value); valid[idx] = true; } else { valid[idx] = false; @@ -1380,10 +1364,10 @@ __global__ void extract_repeated_msg_child_varint_kernel( return; } - int32_t row_offset = msg_row_offsets[idx]; + int32_t row_offset = msg_row_offsets[idx]; uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; - uint8_t const* cur = msg_start + field_loc.offset; - uint8_t const* msg_end = msg_start + msg_loc.length; + uint8_t const* cur = msg_start + field_loc.offset; + uint8_t const* msg_end = msg_start + msg_loc.length; uint8_t const* varint_end = (cur + MAX_VARINT_BYTES < msg_end) ? 
(cur + MAX_VARINT_BYTES) : msg_end; @@ -1397,7 +1381,7 @@ __global__ void extract_repeated_msg_child_varint_kernel( if constexpr (ZigZag) { val = (val >> 1) ^ (-(val & 1)); } - out[idx] = static_cast(val); + out[idx] = static_cast(val); valid[idx] = true; } @@ -1405,29 +1389,28 @@ __global__ void extract_repeated_msg_child_varint_kernel( * Extract fixed-size child fields from repeated message occurrences. */ template -__global__ void extract_repeated_msg_child_fixed_kernel( - uint8_t const* message_data, - int32_t const* msg_row_offsets, - field_location const* msg_locs, - field_location const* child_locs, - int child_idx, - int num_child_fields, - OutputType* out, - bool* valid, - int num_occurrences, - int* error_flag, - bool has_default = false, - OutputType default_value = OutputType{}) +__global__ void extract_repeated_msg_child_fixed_kernel(uint8_t const* message_data, + int32_t const* msg_row_offsets, + field_location const* msg_locs, + field_location const* child_locs, + int child_idx, + int num_child_fields, + OutputType* out, + bool* valid, + int num_occurrences, + int* error_flag, + bool has_default = false, + OutputType default_value = OutputType{}) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx >= num_occurrences) return; - auto const& msg_loc = msg_locs[idx]; + auto const& msg_loc = msg_locs[idx]; auto const& field_loc = child_locs[idx * num_child_fields + child_idx]; if (msg_loc.offset < 0 || field_loc.offset < 0) { if (has_default) { - out[idx] = default_value; + out[idx] = default_value; valid[idx] = true; } else { valid[idx] = false; @@ -1435,9 +1418,9 @@ __global__ void extract_repeated_msg_child_fixed_kernel( return; } - int32_t row_offset = msg_row_offsets[idx]; + int32_t row_offset = msg_row_offsets[idx]; uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; - uint8_t const* cur = msg_start + field_loc.offset; + uint8_t const* cur = msg_start + field_loc.offset; OutputType value; if constexpr (WT == 
WT_32BIT) { @@ -1448,7 +1431,7 @@ __global__ void extract_repeated_msg_child_fixed_kernel( memcpy(&value, &raw, sizeof(value)); } - out[idx] = value; + out[idx] = value; valid[idx] = true; } @@ -1472,19 +1455,19 @@ __global__ void extract_repeated_msg_child_strings_kernel( if (idx >= total_count) return; auto const& field_loc = child_locs[idx * num_child_fields + child_idx]; - + if (field_loc.offset < 0 || field_loc.length == 0) { valid[idx] = false; return; } - + valid[idx] = true; - - int32_t row_offset = msg_row_offsets[idx]; - int32_t msg_offset = msg_locs[idx].offset; + + int32_t row_offset = msg_row_offsets[idx]; + int32_t msg_offset = msg_locs[idx].offset; uint8_t const* str_src = message_data + row_offset + msg_offset + field_loc.offset; - char* str_dst = output_chars + string_offsets[idx]; - + char* str_dst = output_chars + string_offsets[idx]; + // Copy string data memcpy(str_dst, str_src, field_loc.length); } @@ -1492,18 +1475,17 @@ __global__ void extract_repeated_msg_child_strings_kernel( /** * Kernel to compute string lengths from child field locations. */ -__global__ void compute_string_lengths_kernel( - field_location const* child_locs, - int child_idx, - int num_child_fields, - int32_t* lengths, - int total_count) +__global__ void compute_string_lengths_kernel(field_location const* child_locs, + int child_idx, + int num_child_fields, + int32_t* lengths, + int total_count) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= total_count) return; auto const& loc = child_locs[idx * num_child_fields + child_idx]; - lengths[idx] = (loc.offset >= 0) ? loc.length : 0; + lengths[idx] = (loc.offset >= 0) ? 
loc.length : 0; } /** @@ -1522,12 +1504,10 @@ inline std::unique_ptr build_repeated_msg_child_string_column( rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - if (total_count == 0) { - return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); - } + if (total_count == 0) { return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); } auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; + auto const blocks = (total_count + threads - 1) / threads; // Compute string lengths on GPU rmm::device_uvector d_lengths(total_count, stream, mr); @@ -1536,23 +1516,31 @@ inline std::unique_ptr build_repeated_msg_child_string_column( // Compute offsets via exclusive scan rmm::device_uvector d_str_offsets(total_count + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - d_lengths.begin(), d_lengths.end(), - d_str_offsets.begin(), 0); - + thrust::exclusive_scan( + rmm::exec_policy(stream), d_lengths.begin(), d_lengths.end(), d_str_offsets.begin(), 0); + // Get total chars count int32_t total_chars = 0; - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, d_str_offsets.data() + total_count - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, d_lengths.data() + total_count - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, + d_str_offsets.data() + total_count - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + d_lengths.data() + total_count - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); stream.synchronize(); total_chars += last_len; - + // Set final offset - CUDF_CUDA_TRY(cudaMemcpyAsync(d_str_offsets.data() + total_count, &total_chars, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + 
CUDF_CUDA_TRY(cudaMemcpyAsync(d_str_offsets.data() + total_count, + &total_chars, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); // Allocate output chars and validity rmm::device_uvector d_chars(total_chars, stream, mr); @@ -1561,25 +1549,36 @@ inline std::unique_ptr build_repeated_msg_child_string_column( // Extract all strings in parallel on GPU (critical performance fix!) if (total_chars > 0) { extract_repeated_msg_child_strings_kernel<<>>( - message_data, d_msg_row_offsets.data(), d_msg_locs.data(), - d_child_locs.data(), child_idx, num_child_fields, - d_str_offsets.data(), d_chars.data(), d_valid.data(), total_count); + message_data, + d_msg_row_offsets.data(), + d_msg_locs.data(), + d_child_locs.data(), + child_idx, + num_child_fields, + d_str_offsets.data(), + d_chars.data(), + d_valid.data(), + total_count); } else { // No strings, just set validity - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_count), - d_valid.begin(), - [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { - return child_locs[idx * ncf + ci].offset >= 0; - }); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + d_valid.begin(), + [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__( + int idx) { return child_locs[idx * ncf + ci].offset >= 0; }); } auto [mask, null_count] = make_null_mask_from_valid(d_valid, stream, mr); - auto str_offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, total_count + 1, d_str_offsets.release(), rmm::device_buffer{}, 0); - return cudf::make_strings_column(total_count, std::move(str_offsets_col), d_chars.release(), null_count, std::move(mask)); + auto str_offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_count + 1, + d_str_offsets.release(), + rmm::device_buffer{}, 
+ 0); + return cudf::make_strings_column( + total_count, std::move(str_offsets_col), d_chars.release(), null_count, std::move(mask)); } inline std::unique_ptr build_repeated_msg_child_bytes_column( @@ -1595,68 +1594,99 @@ inline std::unique_ptr build_repeated_msg_child_bytes_column( rmm::device_async_resource_ref mr) { if (total_count == 0) { - auto empty_offsets = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, 1, - rmm::device_buffer(sizeof(int32_t), stream, mr), rmm::device_buffer{}, 0); + auto empty_offsets = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 1, + rmm::device_buffer(sizeof(int32_t), stream, mr), + rmm::device_buffer{}, + 0); int32_t zero = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync( - empty_offsets->mutable_view().data(), &zero, sizeof(int32_t), - cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(empty_offsets->mutable_view().data(), + &zero, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); auto empty_bytes = std::make_unique( cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); - return cudf::make_lists_column(0, std::move(empty_offsets), std::move(empty_bytes), - 0, rmm::device_buffer{}, stream, mr); + return cudf::make_lists_column( + 0, std::move(empty_offsets), std::move(empty_bytes), 0, rmm::device_buffer{}, stream, mr); } auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; + auto const blocks = (total_count + threads - 1) / threads; rmm::device_uvector d_lengths(total_count, stream, mr); compute_string_lengths_kernel<<>>( d_child_locs.data(), child_idx, num_child_fields, d_lengths.data(), total_count); rmm::device_uvector d_offs(total_count + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - d_lengths.begin(), d_lengths.end(), - d_offs.begin(), 0); + thrust::exclusive_scan( + rmm::exec_policy(stream), d_lengths.begin(), d_lengths.end(), d_offs.begin(), 0); int32_t total_bytes = 0; - 
int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, d_offs.data() + total_count - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, d_lengths.data() + total_count - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + int32_t last_len = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, + d_offs.data() + total_count - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + d_lengths.data() + total_count - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); stream.synchronize(); total_bytes += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(d_offs.data() + total_count, &total_bytes, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_offs.data() + total_count, + &total_bytes, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector d_bytes(total_bytes, stream, mr); rmm::device_uvector d_valid(total_count, stream, mr); if (total_bytes > 0) { extract_repeated_msg_child_strings_kernel<<>>( - message_data, d_msg_row_offsets.data(), d_msg_locs.data(), - d_child_locs.data(), child_idx, num_child_fields, - d_offs.data(), d_bytes.data(), d_valid.data(), total_count); + message_data, + d_msg_row_offsets.data(), + d_msg_locs.data(), + d_child_locs.data(), + child_idx, + num_child_fields, + d_offs.data(), + d_bytes.data(), + d_valid.data(), + total_count); } else { - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_count), - d_valid.begin(), - [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { - return child_locs[idx * ncf + ci].offset >= 0; - }); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + d_valid.begin(), + [child_locs = d_child_locs.data(), ci = 
child_idx, ncf = num_child_fields] __device__( + int idx) { return child_locs[idx * ncf + ci].offset >= 0; }); } auto [mask, null_count] = make_null_mask_from_valid(d_valid, stream, mr); - auto offs_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, total_count + 1, d_offs.release(), rmm::device_buffer{}, 0); - auto bytes_child = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, total_bytes, - rmm::device_buffer(d_bytes.data(), total_bytes, stream, mr), rmm::device_buffer{}, 0); - return cudf::make_lists_column(total_count, std::move(offs_col), std::move(bytes_child), - null_count, std::move(mask), stream, mr); + auto offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_count + 1, + d_offs.release(), + rmm::device_buffer{}, + 0); + auto bytes_child = + std::make_unique(cudf::data_type{cudf::type_id::UINT8}, + total_bytes, + rmm::device_buffer(d_bytes.data(), total_bytes, stream, mr), + rmm::device_buffer{}, + 0); + return cudf::make_lists_column(total_count, + std::move(offs_col), + std::move(bytes_child), + null_count, + std::move(mask), + stream, + mr); } /** @@ -1665,18 +1695,18 @@ inline std::unique_ptr build_repeated_msg_child_bytes_column( * This is a critical performance optimization. 
*/ __global__ void compute_nested_struct_locations_kernel( - field_location const* child_locs, // Child field locations from parent scan - field_location const* msg_locs, // Parent message locations - int32_t const* msg_row_offsets, // Parent message row offsets - int child_idx, // Which child field is the nested struct - int num_child_fields, // Total number of child fields per occurrence - field_location* nested_locs, // Output: nested struct locations - int32_t* nested_row_offsets, // Output: nested struct row offsets + field_location const* child_locs, // Child field locations from parent scan + field_location const* msg_locs, // Parent message locations + int32_t const* msg_row_offsets, // Parent message row offsets + int child_idx, // Which child field is the nested struct + int num_child_fields, // Total number of child fields per occurrence + field_location* nested_locs, // Output: nested struct locations + int32_t* nested_row_offsets, // Output: nested struct row offsets int total_count) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= total_count) return; - + // Get the nested struct location from child_locs nested_locs[idx] = child_locs[idx * num_child_fields + child_idx]; // Compute absolute row offset = msg_row_offset + msg_offset @@ -1689,19 +1719,19 @@ __global__ void compute_nested_struct_locations_kernel( * This replaces host-side loop with D->H->D copy pattern. 
*/ __global__ void compute_grandchild_parent_locations_kernel( - field_location const* parent_locs, // Parent locations (row count) - field_location const* child_locs, // Child locations (row * num_child_fields) - int child_idx, // Which child field - int num_child_fields, // Total child fields per row - field_location* gc_parent_abs, // Output: absolute grandchild parent locations + field_location const* parent_locs, // Parent locations (row count) + field_location const* child_locs, // Child locations (row * num_child_fields) + int child_idx, // Which child field + int num_child_fields, // Total child fields per row + field_location* gc_parent_abs, // Output: absolute grandchild parent locations int num_rows) { int row = blockIdx.x * blockDim.x + threadIdx.x; if (row >= num_rows) return; - + auto const& parent_loc = parent_locs[row]; - auto const& child_loc = child_locs[row * num_child_fields + child_idx]; - + auto const& child_loc = child_locs[row * num_child_fields + child_idx]; + if (parent_loc.offset >= 0 && child_loc.offset >= 0) { // Absolute offset = parent offset + child's relative offset gc_parent_abs[row].offset = parent_loc.offset + child_loc.offset; @@ -1718,16 +1748,16 @@ __global__ void compute_grandchild_parent_locations_kernel( */ __global__ void compute_virtual_parents_for_nested_repeated_kernel( repeated_occurrence const* occurrences, - cudf::size_type const* row_list_offsets, // original binary input list offsets - field_location const* parent_locations, // parent nested message locations - cudf::size_type* virtual_row_offsets, // output: [total_count] - field_location* virtual_parent_locs, // output: [total_count] + cudf::size_type const* row_list_offsets, // original binary input list offsets + field_location const* parent_locations, // parent nested message locations + cudf::size_type* virtual_row_offsets, // output: [total_count] + field_location* virtual_parent_locs, // output: [total_count] int total_count) { int idx = blockIdx.x * 
blockDim.x + threadIdx.x; if (idx >= total_count) return; - auto const& occ = occurrences[idx]; + auto const& occ = occurrences[idx]; auto const& ploc = parent_locations[occ.row_idx]; virtual_row_offsets[idx] = row_list_offsets[occ.row_idx]; @@ -1756,10 +1786,10 @@ __global__ void compute_msg_locations_from_occurrences_kernel( { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= total_count) return; - - auto const& occ = occurrences[idx]; + + auto const& occ = occurrences[idx]; msg_row_offsets[idx] = static_cast(list_offsets[occ.row_idx] - base_offset); - msg_locs[idx] = {occ.offset, occ.length}; + msg_locs[idx] = {occ.offset, occ.length}; } /** @@ -1770,40 +1800,37 @@ struct extract_strided_count { repeated_field_info const* info; int field_idx; int num_fields; - - __device__ int32_t operator()(int row) const { - return info[row * num_fields + field_idx].count; - } + + __device__ int32_t operator()(int row) const { return info[row * num_fields + field_idx].count; } }; /** * Extract varint from nested message locations. 
*/ template -__global__ void extract_nested_varint_kernel( - uint8_t const* message_data, - cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, - field_location const* parent_locations, - field_location const* field_locations, - int field_idx, - int num_fields, - OutputType* out, - bool* valid, - int num_rows, - int* error_flag, - bool has_default = false, - int64_t default_value = 0) +__global__ void extract_nested_varint_kernel(uint8_t const* message_data, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + field_location const* field_locations, + int field_idx, + int num_fields, + OutputType* out, + bool* valid, + int num_rows, + int* error_flag, + bool has_default = false, + int64_t default_value = 0) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; auto const& parent_loc = parent_locations[row]; - auto const& field_loc = field_locations[row * num_fields + field_idx]; + auto const& field_loc = field_locations[row * num_fields + field_idx]; if (parent_loc.offset < 0 || field_loc.offset < 0) { if (has_default) { - out[row] = static_cast(default_value); + out[row] = static_cast(default_value); valid[row] = true; } else { valid[row] = false; @@ -1811,8 +1838,8 @@ __global__ void extract_nested_varint_kernel( return; } - auto parent_row_start = parent_row_offsets[row] - parent_base_offset; - uint8_t const* cur = message_data + parent_row_start + parent_loc.offset + field_loc.offset; + auto parent_row_start = parent_row_offsets[row] - parent_base_offset; + uint8_t const* cur = message_data + parent_row_start + parent_loc.offset + field_loc.offset; uint8_t const* cur_end = cur + field_loc.length; uint64_t v; @@ -1824,7 +1851,7 @@ __global__ void extract_nested_varint_kernel( } if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } - out[row] = static_cast(v); + out[row] = static_cast(v); valid[row] = true; } @@ -1832,30 
+1859,29 @@ __global__ void extract_nested_varint_kernel( * Extract fixed-size from nested message locations. */ template -__global__ void extract_nested_fixed_kernel( - uint8_t const* message_data, - cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, - field_location const* parent_locations, - field_location const* field_locations, - int field_idx, - int num_fields, - OutputType* out, - bool* valid, - int num_rows, - int* error_flag, - bool has_default = false, - OutputType default_value = OutputType{}) +__global__ void extract_nested_fixed_kernel(uint8_t const* message_data, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + field_location const* field_locations, + int field_idx, + int num_fields, + OutputType* out, + bool* valid, + int num_rows, + int* error_flag, + bool has_default = false, + OutputType default_value = OutputType{}) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; auto const& parent_loc = parent_locations[row]; - auto const& field_loc = field_locations[row * num_fields + field_idx]; + auto const& field_loc = field_locations[row * num_fields + field_idx]; if (parent_loc.offset < 0 || field_loc.offset < 0) { if (has_default) { - out[row] = default_value; + out[row] = default_value; valid[row] = true; } else { valid[row] = false; @@ -1864,7 +1890,7 @@ __global__ void extract_nested_fixed_kernel( } auto parent_row_start = parent_row_offsets[row] - parent_base_offset; - uint8_t const* cur = message_data + parent_row_start + parent_loc.offset + field_loc.offset; + uint8_t const* cur = message_data + parent_row_start + parent_loc.offset + field_loc.offset; OutputType value; if constexpr (WT == WT_32BIT) { @@ -1885,47 +1911,44 @@ __global__ void extract_nested_fixed_kernel( memcpy(&value, &raw, sizeof(value)); } - out[row] = value; + out[row] = value; valid[row] = true; } /** * Copy nested 
variable-length data (string/bytes). */ -__global__ void copy_nested_varlen_data_kernel( - uint8_t const* message_data, - cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, - field_location const* parent_locations, - field_location const* field_locations, - int field_idx, - int num_fields, - int32_t const* output_offsets, - char* output_data, - int num_rows, - bool has_default = false, - uint8_t const* default_data = nullptr, - int32_t default_length = 0) +__global__ void copy_nested_varlen_data_kernel(uint8_t const* message_data, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + field_location const* field_locations, + int field_idx, + int num_fields, + int32_t const* output_offsets, + char* output_data, + int num_rows, + bool has_default = false, + uint8_t const* default_data = nullptr, + int32_t default_length = 0) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; auto const& parent_loc = parent_locations[row]; - auto const& field_loc = field_locations[row * num_fields + field_idx]; + auto const& field_loc = field_locations[row * num_fields + field_idx]; char* dst = output_data + output_offsets[row]; if (parent_loc.offset < 0 || field_loc.offset < 0) { - if (has_default && default_length > 0) { - memcpy(dst, default_data, default_length); - } + if (has_default && default_length > 0) { memcpy(dst, default_data, default_length); } return; } if (field_loc.length == 0) return; auto parent_row_start = parent_row_offsets[row] - parent_base_offset; - uint8_t const* src = message_data + parent_row_start + parent_loc.offset + field_loc.offset; + uint8_t const* src = message_data + parent_row_start + parent_loc.offset + field_loc.offset; memcpy(dst, src, field_loc.length); } @@ -1933,21 +1956,20 @@ __global__ void copy_nested_varlen_data_kernel( /** * Extract nested field lengths for prefix sum. 
*/ -__global__ void extract_nested_lengths_kernel( - field_location const* parent_locations, - field_location const* field_locations, - int field_idx, - int num_fields, - int32_t* lengths, - int num_rows, - bool has_default = false, - int32_t default_length = 0) +__global__ void extract_nested_lengths_kernel(field_location const* parent_locations, + field_location const* field_locations, + int field_idx, + int num_fields, + int32_t* lengths, + int num_rows, + bool has_default = false, + int32_t default_length = 0) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; auto const& parent_loc = parent_locations[row]; - auto const& field_loc = field_locations[row * num_fields + field_idx]; + auto const& field_loc = field_locations[row * num_fields + field_idx]; if (parent_loc.offset >= 0 && field_loc.offset >= 0) { lengths[row] = field_loc.length; @@ -1962,14 +1984,13 @@ __global__ void extract_nested_lengths_kernel( * Extract scalar string field lengths for prefix sum. * For top-level STRING fields (not nested within a struct). */ -__global__ void extract_scalar_string_lengths_kernel( - field_location const* field_locations, - int field_idx, - int num_fields, - int32_t* lengths, - int num_rows, - bool has_default = false, - int32_t default_length = 0) +__global__ void extract_scalar_string_lengths_kernel(field_location const* field_locations, + int field_idx, + int num_fields, + int32_t* lengths, + int num_rows, + bool has_default = false, + int32_t default_length = 0) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; @@ -1989,19 +2010,18 @@ __global__ void extract_scalar_string_lengths_kernel( * Copy scalar string field data. * For top-level STRING fields (not nested within a struct). 
*/ -__global__ void copy_scalar_string_data_kernel( - uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type row_base_offset, - field_location const* field_locations, - int field_idx, - int num_fields, - int32_t const* output_offsets, - char* output_data, - int num_rows, - bool has_default = false, - uint8_t const* default_data = nullptr, - int32_t default_length = 0) +__global__ void copy_scalar_string_data_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type row_base_offset, + field_location const* field_locations, + int field_idx, + int num_fields, + int32_t const* output_offsets, + char* output_data, + int num_rows, + bool has_default = false, + uint8_t const* default_data = nullptr, + int32_t default_length = 0) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; @@ -2012,15 +2032,13 @@ __global__ void copy_scalar_string_data_kernel( if (loc.offset < 0) { // Field not found - use default if available - if (has_default && default_length > 0) { - memcpy(dst, default_data, default_length); - } + if (has_default && default_length > 0) { memcpy(dst, default_data, default_length); } return; } if (loc.length == 0) return; - auto row_start = row_offsets[row] - row_base_offset; + auto row_start = row_offsets[row] - row_base_offset; uint8_t const* src = message_data + row_start + loc.offset; memcpy(dst, src, loc.length); @@ -2030,7 +2048,8 @@ __global__ void copy_scalar_string_data_kernel( // Utility functions // ============================================================================ -// Note: make_null_mask_from_valid is defined earlier in the file (before scan_repeated_message_children_kernel) +// Note: make_null_mask_from_valid is defined earlier in the file (before +// scan_repeated_message_children_kernel) /** * Create an all-null column of the specified type. 
@@ -2120,13 +2139,19 @@ std::unique_ptr make_empty_column_safe(cudf::data_type dtype, switch (dtype.id()) { case cudf::type_id::LIST: { // Create empty list column with empty UINT8 child (Spark BinaryType maps to LIST) - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, 1, rmm::device_buffer(sizeof(int32_t), stream, mr), - rmm::device_buffer{}, 0); + auto offsets_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 1, + rmm::device_buffer(sizeof(int32_t), stream, mr), + rmm::device_buffer{}, + 0); // Initialize offset to 0 int32_t zero = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), &zero, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), + &zero, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); auto child_col = std::make_unique( cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); return cudf::make_lists_column( @@ -2147,23 +2172,19 @@ std::unique_ptr make_empty_column_safe(cudf::data_type dtype, /** * Find all child field indices for a given parent index in the schema. * This is a commonly used pattern throughout the codebase. 
- * - * @param schema The schema vector (either nested_field_descriptor or device_nested_field_descriptor) + * + * @param schema The schema vector (either nested_field_descriptor or + * device_nested_field_descriptor) * @param num_fields Number of fields in the schema * @param parent_idx The parent index to search for * @return Vector of child field indices */ template -std::vector find_child_field_indices( - SchemaT const& schema, - int num_fields, - int parent_idx) +std::vector find_child_field_indices(SchemaT const& schema, int num_fields, int parent_idx) { std::vector child_indices; for (int i = 0; i < num_fields; i++) { - if (schema[i].parent_idx == parent_idx) { - child_indices.push_back(i); - } + if (schema[i].parent_idx == parent_idx) { child_indices.push_back(i); } } return child_indices; } @@ -2171,7 +2192,7 @@ std::vector find_child_field_indices( /** * Recursively create an empty struct column with proper nested structure based on schema. * This handles STRUCT children that contain their own grandchildren. 
- * + * * @param schema The schema vector * @param schema_output_types Output types for each schema field * @param parent_idx Index of the parent field (whose children we want to create) @@ -2190,11 +2211,11 @@ std::unique_ptr make_empty_struct_column_with_schema( rmm::device_async_resource_ref mr) { auto child_indices = find_child_field_indices(schema, num_fields, parent_idx); - + std::vector> children; for (int child_idx : child_indices) { auto child_type = schema_output_types[child_idx]; - + // Recursively handle nested struct children if (child_type.id() == cudf::type_id::STRUCT) { children.push_back(make_empty_struct_column_with_schema( @@ -2203,7 +2224,7 @@ std::unique_ptr make_empty_struct_column_with_schema( children.push_back(make_empty_column_safe(child_type, stream, mr)); } } - + return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer{}, stream, mr); } @@ -2313,10 +2334,10 @@ __global__ void compute_enum_string_lengths_kernel( } int32_t val = values[row]; - int left = 0; - int right = num_valid_values - 1; + int left = 0; + int right = num_valid_values - 1; while (left <= right) { - int mid = left + (right - left) / 2; + int mid = left + (right - left) / 2; int32_t mid_val = valid_enum_values[mid]; if (mid_val == val) { lengths[row] = enum_name_offsets[mid + 1] - enum_name_offsets[mid]; @@ -2351,14 +2372,14 @@ __global__ void copy_enum_string_chars_kernel( if (!valid[row]) return; int32_t val = values[row]; - int left = 0; - int right = num_valid_values - 1; + int left = 0; + int right = num_valid_values - 1; while (left <= right) { - int mid = left + (right - left) / 2; + int mid = left + (right - left) / 2; int32_t mid_val = valid_enum_values[mid]; if (mid_val == val) { int32_t src_begin = enum_name_offsets[mid]; - int32_t src_end = enum_name_offsets[mid + 1]; + int32_t src_end = enum_name_offsets[mid + 1]; int32_t dst_begin = output_offsets[row]; for (int32_t i = 0; i < (src_end - src_begin); ++i) { out_chars[dst_begin + i] = 
static_cast(enum_name_chars[src_begin + i]); @@ -2374,7 +2395,6 @@ __global__ void copy_enum_string_chars_kernel( namespace spark_rapids_jni { - namespace { /** @@ -2401,22 +2421,35 @@ std::unique_ptr build_repeated_scalar_column( // All rows have count=0, but we still need to check input nulls rmm::device_uvector offsets(num_rows + 1, stream, mr); thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), rmm::device_buffer{}, 0); - auto elem_type = field_desc.output_type_id == static_cast(cudf::type_id::LIST) - ? cudf::type_id::UINT8 - : static_cast(field_desc.output_type_id); - auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); - + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + auto elem_type = field_desc.output_type_id == static_cast(cudf::type_id::LIST) + ? 
cudf::type_id::UINT8 + : static_cast(field_desc.output_type_id); + auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); + if (input_null_count > 0) { // Copy input null mask - only input nulls produce output nulls auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), - input_null_count, std::move(null_mask), stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask), + stream, + mr); } else { // No input nulls, all rows get empty arrays [] - return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), - 0, rmm::device_buffer{}, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + 0, + rmm::device_buffer{}, + stream, + mr); } } @@ -2424,29 +2457,36 @@ std::unique_ptr build_repeated_scalar_column( auto const* list_offsets = in_list.offsets().data(); cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&base_offset, list_offsets, sizeof(cudf::size_type), - cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync( + &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); // Build list offsets from counts entirely on GPU (performance fix!) 
// Copy h_repeated_info to device and use thrust::transform to extract counts rmm::device_uvector d_rep_info(num_rows, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_info.data(), h_repeated_info.data(), + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_info.data(), + h_repeated_info.data(), num_rows * sizeof(repeated_field_info), - cudaMemcpyHostToDevice, stream.value())); - + cudaMemcpyHostToDevice, + stream.value())); + rmm::device_uvector counts(num_rows, stream, mr); thrust::transform(rmm::exec_policy(stream), - d_rep_info.begin(), d_rep_info.end(), + d_rep_info.begin(), + d_rep_info.end(), counts.begin(), [] __device__(repeated_field_info const& info) { return info.count; }); rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); - + thrust::exclusive_scan( + rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); + // Set last offset = total_count - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_count, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, + &total_count, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); // Extract values rmm::device_uvector values(total_count, stream, mr); @@ -2454,47 +2494,84 @@ std::unique_ptr build_repeated_scalar_column( CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; + auto const blocks = (total_count + threads - 1) / threads; int encoding = field_desc.encoding; - bool zigzag = (encoding == spark_rapids_jni::ENC_ZIGZAG); - + bool zigzag = (encoding == spark_rapids_jni::ENC_ZIGZAG); + // For float/double types, always use fixed kernel (they use wire type 32BIT/64BIT) // For integer types, use fixed kernel only if encoding is ENC_FIXED constexpr bool is_floating_point = 
std::is_same_v || std::is_same_v; - bool use_fixed_kernel = is_floating_point || (encoding == spark_rapids_jni::ENC_FIXED); + bool use_fixed_kernel = is_floating_point || (encoding == spark_rapids_jni::ENC_FIXED); if (use_fixed_kernel) { if constexpr (sizeof(T) == 4) { - extract_repeated_fixed_kernel<<>>( - message_data, list_offsets, base_offset, d_occurrences.data(), total_count, values.data(), d_error.data()); + extract_repeated_fixed_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_occurrences.data(), + total_count, + values.data(), + d_error.data()); } else { - extract_repeated_fixed_kernel<<>>( - message_data, list_offsets, base_offset, d_occurrences.data(), total_count, values.data(), d_error.data()); + extract_repeated_fixed_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_occurrences.data(), + total_count, + values.data(), + d_error.data()); } } else if (zigzag) { - extract_repeated_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_occurrences.data(), total_count, values.data(), d_error.data()); + extract_repeated_varint_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_occurrences.data(), + total_count, + values.data(), + d_error.data()); } else { - extract_repeated_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_occurrences.data(), total_count, values.data(), d_error.data()); + extract_repeated_varint_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_occurrences.data(), + total_count, + values.data(), + d_error.data()); } - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), rmm::device_buffer{}, 0); - auto child_col = std::make_unique( - cudf::data_type{static_cast(field_desc.output_type_id)}, total_count, values.release(), rmm::device_buffer{}, 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + auto child_col = 
std::make_unique( + cudf::data_type{static_cast(field_desc.output_type_id)}, + total_count, + values.release(), + rmm::device_buffer{}, + 0); // Only rows where INPUT is null should produce null output // Rows with valid input but count=0 should produce empty array [] if (input_null_count > 0) { // Copy input null mask - only input nulls produce output nulls auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), - input_null_count, std::move(null_mask), stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask), + stream, + mr); } - return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); } /** @@ -2520,21 +2597,34 @@ std::unique_ptr build_repeated_string_column( // All rows have count=0, but we still need to check input nulls rmm::device_uvector offsets(num_rows + 1, stream, mr); thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), rmm::device_buffer{}, 0); - auto child_col = is_bytes - ? make_empty_column_safe(cudf::data_type{cudf::type_id::LIST}, stream, mr) // LIST - : cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); - + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + auto child_col = is_bytes ? 
make_empty_column_safe( + cudf::data_type{cudf::type_id::LIST}, stream, mr) // LIST + : cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); + if (input_null_count > 0) { // Copy input null mask - only input nulls produce output nulls auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), - input_null_count, std::move(null_mask), stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask), + stream, + mr); } else { // No input nulls, all rows get empty arrays [] - return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), - 0, rmm::device_buffer{}, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + 0, + rmm::device_buffer{}, + stream, + mr); } } @@ -2543,91 +2633,136 @@ std::unique_ptr build_repeated_string_column( auto const* list_offsets = in_list.offsets().data(); cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&base_offset, list_offsets, sizeof(cudf::size_type), - cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync( + &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); // Build list offsets from counts entirely on GPU (performance fix!) 
// Copy h_repeated_info to device and use thrust::transform to extract counts rmm::device_uvector d_rep_info(num_rows, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_info.data(), h_repeated_info.data(), + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_info.data(), + h_repeated_info.data(), num_rows * sizeof(repeated_field_info), - cudaMemcpyHostToDevice, stream.value())); - + cudaMemcpyHostToDevice, + stream.value())); + rmm::device_uvector counts(num_rows, stream, mr); thrust::transform(rmm::exec_policy(stream), - d_rep_info.begin(), d_rep_info.end(), + d_rep_info.begin(), + d_rep_info.end(), counts.begin(), [] __device__(repeated_field_info const& info) { return info.count; }); rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); - + thrust::exclusive_scan( + rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); + // Set last offset = total_count - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_count, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, + &total_count, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); // Extract string lengths from occurrences rmm::device_uvector str_lengths(total_count, stream, mr); auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; + auto const blocks = (total_count + threads - 1) / threads; extract_repeated_lengths_kernel<<>>( d_occurrences.data(), total_count, str_lengths.data()); // Compute string offsets via prefix sum rmm::device_uvector str_offsets(total_count + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), str_lengths.begin(), str_lengths.end(), str_offsets.begin(), 0); + thrust::exclusive_scan( + rmm::exec_policy(stream), str_lengths.begin(), str_lengths.end(), str_offsets.begin(), 0); int32_t total_chars = 0; - 
CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, str_offsets.data() + total_count - 1, sizeof(int32_t), - cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, + str_offsets.data() + total_count - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, str_lengths.data() + total_count - 1, sizeof(int32_t), - cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + str_lengths.data() + total_count - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); stream.synchronize(); total_chars += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(str_offsets.data() + total_count, &total_chars, sizeof(int32_t), - cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(str_offsets.data() + total_count, + &total_chars, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); // Copy string data rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { - copy_repeated_varlen_data_kernel<<>>( - message_data, list_offsets, base_offset, d_occurrences.data(), total_count, - str_offsets.data(), chars.data()); + copy_repeated_varlen_data_kernel<<>>(message_data, + list_offsets, + base_offset, + d_occurrences.data(), + total_count, + str_offsets.data(), + chars.data()); } // Build the child column (either STRING or LIST) std::unique_ptr child_col; if (is_bytes) { // Build LIST for bytes (Spark BinaryType maps to LIST) - auto str_offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, total_count + 1, str_offsets.release(), rmm::device_buffer{}, 0); - auto bytes_child = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, total_chars, - rmm::device_buffer(chars.data(), total_chars, stream, mr), rmm::device_buffer{}, 0); - child_col = cudf::make_lists_column(total_count, std::move(str_offsets_col), std::move(bytes_child), - 0, rmm::device_buffer{}, stream, mr); + auto str_offsets_col 
= std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_count + 1, + str_offsets.release(), + rmm::device_buffer{}, + 0); + auto bytes_child = + std::make_unique(cudf::data_type{cudf::type_id::UINT8}, + total_chars, + rmm::device_buffer(chars.data(), total_chars, stream, mr), + rmm::device_buffer{}, + 0); + child_col = cudf::make_lists_column(total_count, + std::move(str_offsets_col), + std::move(bytes_child), + 0, + rmm::device_buffer{}, + stream, + mr); } else { // Build STRING column - auto str_offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, total_count + 1, str_offsets.release(), rmm::device_buffer{}, 0); - child_col = cudf::make_strings_column(total_count, std::move(str_offsets_col), chars.release(), 0, rmm::device_buffer{}); + auto str_offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_count + 1, + str_offsets.release(), + rmm::device_buffer{}, + 0); + child_col = cudf::make_strings_column( + total_count, std::move(str_offsets_col), chars.release(), 0, rmm::device_buffer{}); } - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), rmm::device_buffer{}, 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); // Only rows where INPUT is null should produce null output // Rows with valid input but count=0 should produce empty array [] if (input_null_count > 0) { // Copy input null mask - only input nulls produce output nulls auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), - input_null_count, std::move(null_mask), stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask), + stream, + mr); } - return cudf::make_lists_column(num_rows, std::move(offsets_col), 
std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); } // Forward declaration -- build_nested_struct_column is defined after build_repeated_struct_column @@ -2683,15 +2818,18 @@ std::unique_ptr build_repeated_struct_column( rmm::device_async_resource_ref mr) { auto const input_null_count = binary_input.null_count(); - int num_child_fields = static_cast(child_field_indices.size()); + int num_child_fields = static_cast(child_field_indices.size()); if (total_count == 0 || num_child_fields == 0) { // All rows have count=0 or no child fields - return list of empty structs rmm::device_uvector offsets(num_rows + 1, stream, mr); thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), rmm::device_buffer{}, 0); - + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + // Build empty struct child column with proper nested structure int num_schema_fields = static_cast(h_device_schema.size()); std::vector> empty_struct_children; @@ -2705,15 +2843,26 @@ std::unique_ptr build_repeated_struct_column( empty_struct_children.push_back(make_empty_column_safe(child_type, stream, mr)); } } - auto empty_struct = cudf::make_structs_column(0, std::move(empty_struct_children), 0, rmm::device_buffer{}, stream, mr); - + auto empty_struct = cudf::make_structs_column( + 0, std::move(empty_struct_children), 0, rmm::device_buffer{}, stream, mr); + if (input_null_count > 0) { auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(empty_struct), - input_null_count, std::move(null_mask), stream, mr); + return cudf::make_lists_column(num_rows, + 
std::move(offsets_col), + std::move(empty_struct), + input_null_count, + std::move(null_mask), + stream, + mr); } else { - return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(empty_struct), - 0, rmm::device_buffer{}, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(empty_struct), + 0, + rmm::device_buffer{}, + stream, + mr); } } @@ -2722,41 +2871,50 @@ std::unique_ptr build_repeated_struct_column( auto const* list_offsets = in_list.offsets().data(); cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&base_offset, list_offsets, sizeof(cudf::size_type), - cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync( + &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); // Build list offsets from counts entirely on GPU (performance fix!) // Copy repeated_info to device and use thrust::transform to extract counts rmm::device_uvector d_rep_info(num_rows, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_info.data(), h_repeated_info.data(), + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_info.data(), + h_repeated_info.data(), num_rows * sizeof(repeated_field_info), - cudaMemcpyHostToDevice, stream.value())); - + cudaMemcpyHostToDevice, + stream.value())); + rmm::device_uvector counts(num_rows, stream, mr); thrust::transform(rmm::exec_policy(stream), - d_rep_info.begin(), d_rep_info.end(), + d_rep_info.begin(), + d_rep_info.end(), counts.begin(), [] __device__(repeated_field_info const& info) { return info.count; }); - + rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); - + thrust::exclusive_scan( + rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); + // Set last offset = total_count (already computed on caller side) - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, 
&total_count, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, + &total_count, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); // Build child field descriptors for scanning within each message occurrence std::vector h_child_descs(num_child_fields); for (int ci = 0; ci < num_child_fields; ci++) { - int child_schema_idx = child_field_indices[ci]; - h_child_descs[ci].field_number = h_device_schema[child_schema_idx].field_number; + int child_schema_idx = child_field_indices[ci]; + h_child_descs[ci].field_number = h_device_schema[child_schema_idx].field_number; h_child_descs[ci].expected_wire_type = h_device_schema[child_schema_idx].wire_type; } rmm::device_uvector d_child_descs(num_child_fields, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_descs.data(), h_child_descs.data(), + CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_descs.data(), + h_child_descs.data(), num_child_fields * sizeof(field_descriptor), - cudaMemcpyHostToDevice, stream.value())); + cudaMemcpyHostToDevice, + stream.value())); // For each occurrence, we need to scan for child fields // Create "virtual" parent locations from the occurrences using GPU kernel @@ -2765,10 +2923,14 @@ std::unique_ptr build_repeated_struct_column( rmm::device_uvector d_msg_row_offsets(total_count, stream, mr); { auto const occ_threads = THREADS_PER_BLOCK; - auto const occ_blocks = (total_count + occ_threads - 1) / occ_threads; + auto const occ_blocks = (total_count + occ_threads - 1) / occ_threads; compute_msg_locations_from_occurrences_kernel<<>>( - d_occurrences.data(), list_offsets, base_offset, - d_msg_locs.data(), d_msg_row_offsets.data(), total_count); + d_occurrences.data(), + list_offsets, + base_offset, + d_msg_locs.data(), + d_msg_row_offsets.data(), + total_count); } // Scan for child fields within each message occurrence @@ -2777,13 +2939,19 @@ std::unique_ptr build_repeated_struct_column( 
CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; + auto const blocks = (total_count + threads - 1) / threads; // Use a custom kernel to scan child fields within message occurrences // This is similar to scan_nested_message_fields_kernel but operates on occurrences scan_repeated_message_children_kernel<<>>( - message_data, d_msg_row_offsets.data(), d_msg_locs.data(), total_count, - d_child_descs.data(), num_child_fields, d_child_locs.data(), d_error.data()); + message_data, + d_msg_row_offsets.data(), + d_msg_locs.data(), + total_count, + d_child_descs.data(), + num_child_fields, + d_child_locs.data(), + d_error.data()); // Note: We no longer need to copy child_locs to host because: // 1. All scalar extraction kernels access d_child_locs directly on device @@ -2794,19 +2962,28 @@ std::unique_ptr build_repeated_struct_column( std::vector> struct_children; for (int ci = 0; ci < num_child_fields; ci++) { int child_schema_idx = child_field_indices[ci]; - auto const dt = schema_output_types[child_schema_idx]; - auto const enc = h_device_schema[child_schema_idx].encoding; - bool has_def = h_device_schema[child_schema_idx].has_default_value; + auto const dt = schema_output_types[child_schema_idx]; + auto const enc = h_device_schema[child_schema_idx].encoding; + bool has_def = h_device_schema[child_schema_idx].has_default_value; switch (dt.id()) { case cudf::type_id::BOOL8: { rmm::device_uvector out(total_count, stream, mr); rmm::device_uvector valid(total_count, stream, mr); int64_t def_val = has_def ? (default_bools[child_schema_idx] ? 
1 : 0) : 0; - extract_repeated_msg_child_varint_kernel<<>>( - message_data, d_msg_row_offsets.data(), d_msg_locs.data(), - d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), - total_count, d_error.data(), has_def, def_val); + extract_repeated_msg_child_varint_kernel + <<>>(message_data, + d_msg_row_offsets.data(), + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + total_count, + d_error.data(), + has_def, + def_val); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); struct_children.push_back(std::make_unique( dt, total_count, out.release(), std::move(mask), null_count)); @@ -2817,20 +2994,47 @@ std::unique_ptr build_repeated_struct_column( rmm::device_uvector valid(total_count, stream, mr); int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_repeated_msg_child_varint_kernel<<>>( - message_data, d_msg_row_offsets.data(), d_msg_locs.data(), - d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), - total_count, d_error.data(), has_def, def_int); + extract_repeated_msg_child_varint_kernel + <<>>(message_data, + d_msg_row_offsets.data(), + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + total_count, + d_error.data(), + has_def, + def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_repeated_msg_child_fixed_kernel<<>>( - message_data, d_msg_row_offsets.data(), d_msg_locs.data(), - d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), - total_count, d_error.data(), has_def, static_cast(def_int)); + extract_repeated_msg_child_fixed_kernel + <<>>(message_data, + d_msg_row_offsets.data(), + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + total_count, + d_error.data(), + has_def, + static_cast(def_int)); } else { - extract_repeated_msg_child_varint_kernel<<>>( - message_data, 
d_msg_row_offsets.data(), d_msg_locs.data(), - d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), - total_count, d_error.data(), has_def, def_int); + extract_repeated_msg_child_varint_kernel + <<>>(message_data, + d_msg_row_offsets.data(), + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + total_count, + d_error.data(), + has_def, + def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); struct_children.push_back(std::make_unique( @@ -2842,20 +3046,47 @@ std::unique_ptr build_repeated_struct_column( rmm::device_uvector valid(total_count, stream, mr); int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_repeated_msg_child_varint_kernel<<>>( - message_data, d_msg_row_offsets.data(), d_msg_locs.data(), - d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), - total_count, d_error.data(), has_def, def_int); + extract_repeated_msg_child_varint_kernel + <<>>(message_data, + d_msg_row_offsets.data(), + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + total_count, + d_error.data(), + has_def, + def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_repeated_msg_child_fixed_kernel<<>>( - message_data, d_msg_row_offsets.data(), d_msg_locs.data(), - d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), - total_count, d_error.data(), has_def, def_int); + extract_repeated_msg_child_fixed_kernel + <<>>(message_data, + d_msg_row_offsets.data(), + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + total_count, + d_error.data(), + has_def, + def_int); } else { - extract_repeated_msg_child_varint_kernel<<>>( - message_data, d_msg_row_offsets.data(), d_msg_locs.data(), - d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), - total_count, d_error.data(), has_def, 
def_int); + extract_repeated_msg_child_varint_kernel + <<>>(message_data, + d_msg_row_offsets.data(), + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + total_count, + d_error.data(), + has_def, + def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); struct_children.push_back(std::make_unique( @@ -2866,10 +3097,19 @@ std::unique_ptr build_repeated_struct_column( rmm::device_uvector out(total_count, stream, mr); rmm::device_uvector valid(total_count, stream, mr); float def_float = has_def ? static_cast(default_floats[child_schema_idx]) : 0.0f; - extract_repeated_msg_child_fixed_kernel<<>>( - message_data, d_msg_row_offsets.data(), d_msg_locs.data(), - d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), - total_count, d_error.data(), has_def, def_float); + extract_repeated_msg_child_fixed_kernel + <<>>(message_data, + d_msg_row_offsets.data(), + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + total_count, + d_error.data(), + has_def, + def_float); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); struct_children.push_back(std::make_unique( dt, total_count, out.release(), std::move(mask), null_count)); @@ -2879,10 +3119,19 @@ std::unique_ptr build_repeated_struct_column( rmm::device_uvector out(total_count, stream, mr); rmm::device_uvector valid(total_count, stream, mr); double def_double = has_def ? 
default_floats[child_schema_idx] : 0.0; - extract_repeated_msg_child_fixed_kernel<<>>( - message_data, d_msg_row_offsets.data(), d_msg_locs.data(), - d_child_locs.data(), ci, num_child_fields, out.data(), valid.data(), - total_count, d_error.data(), has_def, def_double); + extract_repeated_msg_child_fixed_kernel + <<>>(message_data, + d_msg_row_offsets.data(), + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + total_count, + d_error.data(), + has_def, + def_double); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); struct_children.push_back(std::make_unique( dt, total_count, out.release(), std::move(mask), null_count)); @@ -2890,28 +3139,46 @@ std::unique_ptr build_repeated_struct_column( } case cudf::type_id::STRING: { // For strings, we need a two-pass approach: first get lengths, then copy data - struct_children.push_back( - build_repeated_msg_child_string_column( - message_data, d_msg_row_offsets, d_msg_locs, - d_child_locs, ci, num_child_fields, total_count, d_error, stream, mr)); + struct_children.push_back(build_repeated_msg_child_string_column(message_data, + d_msg_row_offsets, + d_msg_locs, + d_child_locs, + ci, + num_child_fields, + total_count, + d_error, + stream, + mr)); break; } case cudf::type_id::LIST: { // bytes (BinaryType) child inside repeated message - struct_children.push_back( - build_repeated_msg_child_bytes_column( - message_data, d_msg_row_offsets, d_msg_locs, - d_child_locs, ci, num_child_fields, total_count, d_error, stream, mr)); + struct_children.push_back(build_repeated_msg_child_bytes_column(message_data, + d_msg_row_offsets, + d_msg_locs, + d_child_locs, + ci, + num_child_fields, + total_count, + d_error, + stream, + mr)); break; } case cudf::type_id::STRUCT: { // Nested struct inside repeated message - use recursive build_nested_struct_column int num_schema_fields = static_cast(h_device_schema.size()); - auto grandchild_indices = 
find_child_field_indices(h_device_schema, num_schema_fields, child_schema_idx); - + auto grandchild_indices = + find_child_field_indices(h_device_schema, num_schema_fields, child_schema_idx); + if (grandchild_indices.empty()) { - struct_children.push_back(cudf::make_structs_column( - total_count, std::vector>{}, 0, rmm::device_buffer{}, stream, mr)); + struct_children.push_back( + cudf::make_structs_column(total_count, + std::vector>{}, + 0, + rmm::device_buffer{}, + stream, + mr)); } else { // Compute virtual parent locations for each occurrence's nested struct child rmm::device_uvector d_nested_locs(total_count, stream, mr); @@ -2920,23 +3187,44 @@ std::unique_ptr build_repeated_struct_column( // Convert int32_t row offsets to cudf::size_type and compute nested struct locations rmm::device_uvector d_nested_row_offsets_i32(total_count, stream, mr); compute_nested_struct_locations_kernel<<>>( - d_child_locs.data(), d_msg_locs.data(), d_msg_row_offsets.data(), - ci, num_child_fields, d_nested_locs.data(), d_nested_row_offsets_i32.data(), total_count); + d_child_locs.data(), + d_msg_locs.data(), + d_msg_row_offsets.data(), + ci, + num_child_fields, + d_nested_locs.data(), + d_nested_row_offsets_i32.data(), + total_count); // Add base_offset back so build_nested_struct_column can subtract it thrust::transform(rmm::exec_policy(stream), - d_nested_row_offsets_i32.begin(), d_nested_row_offsets_i32.end(), + d_nested_row_offsets_i32.begin(), + d_nested_row_offsets_i32.end(), d_nested_row_offsets.begin(), [base_offset] __device__(int32_t v) { return static_cast(v) + base_offset; }); } - struct_children.push_back(build_nested_struct_column( - message_data, d_nested_row_offsets.data(), base_offset, d_nested_locs, - grandchild_indices, schema, num_schema_fields, schema_output_types, - default_ints, default_floats, default_bools, default_strings, - enum_valid_values, enum_names, d_row_has_invalid_enum, d_error_top, - total_count, stream, mr, 0)); + 
struct_children.push_back(build_nested_struct_column(message_data, + d_nested_row_offsets.data(), + base_offset, + d_nested_locs, + grandchild_indices, + schema, + num_schema_fields, + schema_output_types, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error_top, + total_count, + stream, + mr, + 0)); } break; } @@ -2948,20 +3236,30 @@ std::unique_ptr build_repeated_struct_column( } // Build the struct column from child columns - auto struct_col = cudf::make_structs_column(total_count, std::move(struct_children), 0, rmm::device_buffer{}, stream, mr); + auto struct_col = cudf::make_structs_column( + total_count, std::move(struct_children), 0, rmm::device_buffer{}, stream, mr); // Build the list offsets column - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), rmm::device_buffer{}, 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); // Build the final LIST column if (input_null_count > 0) { auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(struct_col), - input_null_count, std::move(null_mask), stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(struct_col), + input_null_count, + std::move(null_mask), + stream, + mr); } - return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(struct_col), 0, rmm::device_buffer{}, stream, mr); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(struct_col), 0, rmm::device_buffer{}, stream, mr); } /** @@ -3004,38 +3302,48 @@ std::unique_ptr build_nested_struct_column( empty_children.push_back(make_empty_column_safe(child_type, stream, mr)); } } - return cudf::make_structs_column(0, std::move(empty_children), 0, 
rmm::device_buffer{}, stream, mr); + return cudf::make_structs_column( + 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); } - auto const threads = THREADS_PER_BLOCK; - auto const blocks = (num_rows + threads - 1) / threads; + auto const threads = THREADS_PER_BLOCK; + auto const blocks = (num_rows + threads - 1) / threads; int num_child_fields = static_cast(child_field_indices.size()); std::vector h_child_field_descs(num_child_fields); for (int i = 0; i < num_child_fields; i++) { - int child_idx = child_field_indices[i]; - h_child_field_descs[i].field_number = schema[child_idx].field_number; + int child_idx = child_field_indices[i]; + h_child_field_descs[i].field_number = schema[child_idx].field_number; h_child_field_descs[i].expected_wire_type = schema[child_idx].wire_type; } rmm::device_uvector d_child_field_descs(num_child_fields, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_field_descs.data(), h_child_field_descs.data(), + CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_field_descs.data(), + h_child_field_descs.data(), num_child_fields * sizeof(field_descriptor), - cudaMemcpyHostToDevice, stream.value())); + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector d_child_locations( static_cast(num_rows) * num_child_fields, stream, mr); scan_nested_message_fields_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, - d_child_field_descs.data(), num_child_fields, d_child_locations.data(), d_error.data()); + message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + num_rows, + d_child_field_descs.data(), + num_child_fields, + d_child_locations.data(), + d_error.data()); std::vector> struct_children; for (int ci = 0; ci < num_child_fields; ci++) { int child_schema_idx = child_field_indices[ci]; - auto const dt = schema_output_types[child_schema_idx]; - auto const enc = schema[child_schema_idx].encoding; - bool has_def = schema[child_schema_idx].has_default_value; - bool is_repeated = 
schema[child_schema_idx].is_repeated; + auto const dt = schema_output_types[child_schema_idx]; + auto const enc = schema[child_schema_idx].encoding; + bool has_def = schema[child_schema_idx].has_default_value; + bool is_repeated = schema[child_schema_idx].is_repeated; if (is_repeated) { auto elem_type_id = schema[child_schema_idx].output_type; @@ -3043,43 +3351,61 @@ std::unique_ptr build_nested_struct_column( std::vector rep_indices = {0}; rmm::device_uvector d_rep_indices(1, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_indices.data(), rep_indices.data(), - sizeof(int), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_indices.data(), + rep_indices.data(), + sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); device_nested_field_descriptor rep_desc; - rep_desc.field_number = schema[child_schema_idx].field_number; - rep_desc.wire_type = schema[child_schema_idx].wire_type; - rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); - rep_desc.is_repeated = true; - rep_desc.parent_idx = -1; - rep_desc.depth = 0; - rep_desc.encoding = 0; - rep_desc.is_required = false; + rep_desc.field_number = schema[child_schema_idx].field_number; + rep_desc.wire_type = schema[child_schema_idx].wire_type; + rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); + rep_desc.is_repeated = true; + rep_desc.parent_idx = -1; + rep_desc.depth = 0; + rep_desc.encoding = 0; + rep_desc.is_required = false; rep_desc.has_default_value = false; std::vector h_rep_schema = {rep_desc}; rmm::device_uvector d_rep_schema(1, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_schema.data(), h_rep_schema.data(), - sizeof(device_nested_field_descriptor), cudaMemcpyHostToDevice, stream.value())); - - count_repeated_in_nested_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, - d_rep_schema.data(), 1, d_rep_info.data(), 1, d_rep_indices.data(), d_error.data()); + 
CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_schema.data(), + h_rep_schema.data(), + sizeof(device_nested_field_descriptor), + cudaMemcpyHostToDevice, + stream.value())); + + count_repeated_in_nested_kernel<<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + num_rows, + d_rep_schema.data(), + 1, + d_rep_info.data(), + 1, + d_rep_indices.data(), + d_error.data()); rmm::device_uvector d_rep_counts(num_rows, stream, mr); thrust::transform(rmm::exec_policy(stream), - d_rep_info.begin(), d_rep_info.end(), + d_rep_info.begin(), + d_rep_info.end(), d_rep_counts.begin(), [] __device__(repeated_field_info const& info) { return info.count; }); - int total_rep_count = thrust::reduce(rmm::exec_policy(stream), - d_rep_counts.begin(), d_rep_counts.end(), 0); + int total_rep_count = + thrust::reduce(rmm::exec_policy(stream), d_rep_counts.begin(), d_rep_counts.end(), 0); if (total_rep_count == 0) { rmm::device_uvector list_offsets_vec(num_rows + 1, stream, mr); thrust::fill(rmm::exec_policy(stream), list_offsets_vec.begin(), list_offsets_vec.end(), 0); - auto list_offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offsets_vec.release(), - rmm::device_buffer{}, 0); + auto list_offsets_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offsets_vec.release(), + rmm::device_buffer{}, + 0); std::unique_ptr child_col; if (elem_type_id == cudf::type_id::STRUCT) { child_col = make_empty_struct_column_with_schema( @@ -3087,151 +3413,296 @@ std::unique_ptr build_nested_struct_column( } else { child_col = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); } - struct_children.push_back(cudf::make_lists_column( - num_rows, std::move(list_offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr)); + struct_children.push_back(cudf::make_lists_column(num_rows, + std::move(list_offsets_col), + std::move(child_col), + 0, + rmm::device_buffer{}, + stream, + mr)); } else { 
rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - d_rep_counts.begin(), d_rep_counts.end(), - list_offs.begin(), 0); - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_rep_count, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + thrust::exclusive_scan( + rmm::exec_policy(stream), d_rep_counts.begin(), d_rep_counts.end(), list_offs.begin(), 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, + &total_rep_count, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); - scan_repeated_in_nested_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, - d_rep_schema.data(), 1, list_offs.data(), 1, d_rep_indices.data(), - d_rep_occs.data(), d_error.data()); + scan_repeated_in_nested_kernel<<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + num_rows, + d_rep_schema.data(), + 1, + list_offs.data(), + 1, + d_rep_indices.data(), + d_rep_occs.data(), + d_error.data()); std::unique_ptr child_values; if (elem_type_id == cudf::type_id::INT32) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_varint_kernel<<< - (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); - child_values = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, total_rep_count, values.release(), rmm::device_buffer{}, 0); + extract_repeated_in_nested_varint_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_rep_occs.data(), + total_rep_count, + values.data(), + d_error.data()); + child_values = 
std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); } else if (elem_type_id == cudf::type_id::INT64) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_varint_kernel<<< - (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); - child_values = std::make_unique( - cudf::data_type{cudf::type_id::INT64}, total_rep_count, values.release(), rmm::device_buffer{}, 0); + extract_repeated_in_nested_varint_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_rep_occs.data(), + total_rep_count, + values.data(), + d_error.data()); + child_values = std::make_unique(cudf::data_type{cudf::type_id::INT64}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); } else if (elem_type_id == cudf::type_id::BOOL8) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_varint_kernel<<< - (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); - child_values = std::make_unique( - cudf::data_type{cudf::type_id::BOOL8}, total_rep_count, values.release(), rmm::device_buffer{}, 0); + extract_repeated_in_nested_varint_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_rep_occs.data(), + total_rep_count, + values.data(), + d_error.data()); + child_values = 
std::make_unique(cudf::data_type{cudf::type_id::BOOL8}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); } else if (elem_type_id == cudf::type_id::FLOAT32) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_fixed_kernel<<< - (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); - child_values = std::make_unique( - cudf::data_type{cudf::type_id::FLOAT32}, total_rep_count, values.release(), rmm::device_buffer{}, 0); + extract_repeated_in_nested_fixed_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_rep_occs.data(), + total_rep_count, + values.data(), + d_error.data()); + child_values = std::make_unique(cudf::data_type{cudf::type_id::FLOAT32}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); } else if (elem_type_id == cudf::type_id::FLOAT64) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_fixed_kernel<<< - (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_rep_occs.data(), total_rep_count, values.data(), d_error.data()); - child_values = std::make_unique( - cudf::data_type{cudf::type_id::FLOAT64}, total_rep_count, values.release(), rmm::device_buffer{}, 0); + extract_repeated_in_nested_fixed_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_rep_occs.data(), + total_rep_count, + values.data(), + d_error.data()); + child_values = 
std::make_unique(cudf::data_type{cudf::type_id::FLOAT64}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); } else if (elem_type_id == cudf::type_id::STRING) { rmm::device_uvector d_str_lengths(total_rep_count, stream, mr); thrust::transform(rmm::exec_policy(stream), - d_rep_occs.begin(), d_rep_occs.end(), + d_rep_occs.begin(), + d_rep_occs.end(), d_str_lengths.begin(), [] __device__(repeated_occurrence const& occ) { return occ.length; }); - int32_t total_chars = thrust::reduce(rmm::exec_policy(stream), - d_str_lengths.begin(), d_str_lengths.end(), 0); + int32_t total_chars = + thrust::reduce(rmm::exec_policy(stream), d_str_lengths.begin(), d_str_lengths.end(), 0); rmm::device_uvector str_offs(total_rep_count + 1, stream, mr); thrust::exclusive_scan(rmm::exec_policy(stream), - d_str_lengths.begin(), d_str_lengths.end(), - str_offs.begin(), 0); - CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data() + total_rep_count, &total_chars, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + d_str_lengths.begin(), + d_str_lengths.end(), + str_offs.begin(), + 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data() + total_rep_count, + &total_chars, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { - extract_repeated_in_nested_string_kernel<<< - (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_rep_occs.data(), total_rep_count, str_offs.data(), chars.data(), d_error.data()); + extract_repeated_in_nested_string_kernel<<<(total_rep_count + THREADS_PER_BLOCK - 1) / + THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_rep_occs.data(), + total_rep_count, + str_offs.data(), + chars.data(), + d_error.data()); } - auto str_offs_col = std::make_unique( - 
cudf::data_type{cudf::type_id::INT32}, total_rep_count + 1, str_offs.release(), rmm::device_buffer{}, 0); - child_values = cudf::make_strings_column(total_rep_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); + auto str_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_rep_count + 1, + str_offs.release(), + rmm::device_buffer{}, + 0); + child_values = cudf::make_strings_column( + total_rep_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); } else if (elem_type_id == cudf::type_id::LIST) { rmm::device_uvector d_len(total_rep_count, stream, mr); thrust::transform(rmm::exec_policy(stream), - d_rep_occs.begin(), d_rep_occs.end(), + d_rep_occs.begin(), + d_rep_occs.end(), d_len.begin(), [] __device__(repeated_occurrence const& occ) { return occ.length; }); - int32_t total_bytes = thrust::reduce(rmm::exec_policy(stream), - d_len.begin(), d_len.end(), 0); + int32_t total_bytes = + thrust::reduce(rmm::exec_policy(stream), d_len.begin(), d_len.end(), 0); rmm::device_uvector byte_offs(total_rep_count + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - d_len.begin(), d_len.end(), - byte_offs.begin(), 0); - CUDF_CUDA_TRY(cudaMemcpyAsync(byte_offs.data() + total_rep_count, &total_bytes, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + thrust::exclusive_scan( + rmm::exec_policy(stream), d_len.begin(), d_len.end(), byte_offs.begin(), 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(byte_offs.data() + total_rep_count, + &total_bytes, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector bytes(total_bytes, stream, mr); if (total_bytes > 0) { - extract_repeated_in_nested_string_kernel<<< - (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, stream.value()>>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_rep_occs.data(), total_rep_count, byte_offs.data(), bytes.data(), d_error.data()); + 
extract_repeated_in_nested_string_kernel<<<(total_rep_count + THREADS_PER_BLOCK - 1) / + THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_rep_occs.data(), + total_rep_count, + byte_offs.data(), + bytes.data(), + d_error.data()); } - auto offs_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, total_rep_count + 1, byte_offs.release(), rmm::device_buffer{}, 0); + auto offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_rep_count + 1, + byte_offs.release(), + rmm::device_buffer{}, + 0); auto bytes_child = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, total_bytes, - rmm::device_buffer(bytes.data(), total_bytes, stream, mr), rmm::device_buffer{}, 0); - child_values = cudf::make_lists_column( - total_rep_count, std::move(offs_col), std::move(bytes_child), 0, rmm::device_buffer{}, stream, mr); + cudf::data_type{cudf::type_id::UINT8}, + total_bytes, + rmm::device_buffer(bytes.data(), total_bytes, stream, mr), + rmm::device_buffer{}, + 0); + child_values = cudf::make_lists_column(total_rep_count, + std::move(offs_col), + std::move(bytes_child), + 0, + rmm::device_buffer{}, + stream, + mr); } else if (elem_type_id == cudf::type_id::STRUCT) { // Repeated message field (ArrayType(StructType)) inside nested message. // Build virtual parent info for each occurrence so we can recursively decode children. 
auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx); if (gc_indices.empty()) { - child_values = cudf::make_structs_column( - total_rep_count, std::vector>{}, - 0, rmm::device_buffer{}, stream, mr); + child_values = cudf::make_structs_column(total_rep_count, + std::vector>{}, + 0, + rmm::device_buffer{}, + stream, + mr); } else { rmm::device_uvector d_virtual_row_offsets(total_rep_count, stream, mr); rmm::device_uvector d_virtual_parent_locs(total_rep_count, stream, mr); auto const rep_blk = (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; - compute_virtual_parents_for_nested_repeated_kernel<<>>( - d_rep_occs.data(), list_offsets, d_parent_locs.data(), - d_virtual_row_offsets.data(), d_virtual_parent_locs.data(), total_rep_count); - - child_values = build_nested_struct_column( - message_data, d_virtual_row_offsets.data(), base_offset, d_virtual_parent_locs, - gc_indices, schema, num_fields, schema_output_types, default_ints, default_floats, - default_bools, default_strings, enum_valid_values, enum_names, - d_row_has_invalid_enum, d_error, total_rep_count, stream, mr, depth + 1); + compute_virtual_parents_for_nested_repeated_kernel<<>>( + d_rep_occs.data(), + list_offsets, + d_parent_locs.data(), + d_virtual_row_offsets.data(), + d_virtual_parent_locs.data(), + total_rep_count); + + child_values = build_nested_struct_column(message_data, + d_virtual_row_offsets.data(), + base_offset, + d_virtual_parent_locs, + gc_indices, + schema, + num_fields, + schema_output_types, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error, + total_rep_count, + stream, + mr, + depth + 1); } } else { child_values = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); } - auto list_offs_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), rmm::device_buffer{}, 0); - 
struct_children.push_back(cudf::make_lists_column( - num_rows, std::move(list_offs_col), std::move(child_values), 0, rmm::device_buffer{}, stream, mr)); + auto list_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + struct_children.push_back(cudf::make_lists_column(num_rows, + std::move(list_offs_col), + std::move(child_values), + 0, + rmm::device_buffer{}, + stream, + mr)); } continue; } @@ -3241,13 +3712,23 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid(num_rows, stream, mr); int64_t def_val = has_def ? (default_bools[child_schema_idx] ? 1 : 0) : 0; - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_val); + extract_nested_varint_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_val); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count)); + struct_children.push_back( + std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); break; } case cudf::type_id::INT32: { @@ -3255,24 +3736,54 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector valid(num_rows, stream, mr); int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); + extract_nested_varint_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_nested_fixed_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, static_cast(def_int)); + extract_nested_fixed_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + static_cast(def_int)); } else { - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); + extract_nested_varint_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count)); + struct_children.push_back( + std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); break; } case cudf::type_id::UINT32: { @@ -3280,19 +3791,39 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector valid(num_rows, 
stream, mr); int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; if (enc == spark_rapids_jni::ENC_FIXED) { - extract_nested_fixed_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, static_cast(def_int)); + extract_nested_fixed_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + static_cast(def_int)); } else { - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); + extract_nested_varint_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count)); + struct_children.push_back( + std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); break; } case cudf::type_id::INT64: { @@ -3300,24 +3831,54 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector valid(num_rows, stream, mr); int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); + extract_nested_varint_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_nested_fixed_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); + extract_nested_fixed_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } else { - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); + extract_nested_varint_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count)); + struct_children.push_back( + std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); break; } case cudf::type_id::UINT64: { @@ -3325,45 +3886,85 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector valid(num_rows, stream, mr); int64_t 
def_int = has_def ? default_ints[child_schema_idx] : 0; if (enc == spark_rapids_jni::ENC_FIXED) { - extract_nested_fixed_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, static_cast(def_int)); + extract_nested_fixed_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + static_cast(def_int)); } else { - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); + extract_nested_varint_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count)); + struct_children.push_back( + std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); break; } case cudf::type_id::FLOAT32: { rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid(num_rows, stream, mr); float def_float = has_def ? 
static_cast(default_floats[child_schema_idx]) : 0.0f; - extract_nested_fixed_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_float); + extract_nested_fixed_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_float); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count)); + struct_children.push_back( + std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); break; } case cudf::type_id::FLOAT64: { rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid(num_rows, stream, mr); double def_double = has_def ? default_floats[child_schema_idx] : 0.0; - extract_nested_fixed_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_double); + extract_nested_fixed_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_double); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count)); + struct_children.push_back( + std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); break; } case cudf::type_id::STRING: { @@ -3371,22 +3972,39 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid(num_rows, stream, mr); int64_t def_int = 
has_def ? default_ints[child_schema_idx] : 0; - extract_nested_varint_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, out.data(), valid.data(), - num_rows, d_error.data(), has_def, def_int); + extract_nested_varint_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); if (child_schema_idx < static_cast(enum_valid_values.size()) && child_schema_idx < static_cast(enum_names.size())) { - auto const& valid_enums = enum_valid_values[child_schema_idx]; + auto const& valid_enums = enum_valid_values[child_schema_idx]; auto const& enum_name_bytes = enum_names[child_schema_idx]; if (!valid_enums.empty() && valid_enums.size() == enum_name_bytes.size()) { rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), valid_enums.data(), - valid_enums.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); validate_enum_values_kernel<<>>( - out.data(), valid.data(), d_row_has_invalid_enum.data(), - d_valid_enums.data(), static_cast(valid_enums.size()), num_rows); + out.data(), + valid.data(), + d_row_has_invalid_enum.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + num_rows); std::vector h_name_offsets(valid_enums.size() + 1, 0); int32_t total_name_chars = 0; @@ -3404,46 +4022,78 @@ std::unique_ptr build_nested_struct_column( } rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), h_name_offsets.data(), - h_name_offsets.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + 
CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), + h_name_offsets.data(), + h_name_offsets.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector d_name_chars(total_name_chars, stream, mr); if (total_name_chars > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), h_name_chars.data(), - total_name_chars * sizeof(uint8_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), + h_name_chars.data(), + total_name_chars * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream.value())); } rmm::device_uvector lengths(num_rows, stream, mr); compute_enum_string_lengths_kernel<<>>( - out.data(), valid.data(), d_valid_enums.data(), d_name_offsets.data(), - static_cast(valid_enums.size()), lengths.data(), num_rows); + out.data(), + valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + static_cast(valid_enums.size()), + lengths.data(), + num_rows); rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), - output_offsets.begin(), 0); + thrust::exclusive_scan(rmm::exec_policy(stream), + lengths.begin(), + lengths.end(), + output_offsets.begin(), + 0); int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, output_offsets.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, + output_offsets.data() + num_rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + lengths.data() + num_rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); stream.synchronize(); total_chars += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_chars, - 
sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, + &total_chars, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { copy_enum_string_chars_kernel<<>>( - out.data(), valid.data(), d_valid_enums.data(), d_name_offsets.data(), - d_name_chars.data(), static_cast(valid_enums.size()), - output_offsets.data(), chars.data(), num_rows); + out.data(), + valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + d_name_chars.data(), + static_cast(valid_enums.size()), + output_offsets.data(), + chars.data(), + num_rows); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, output_offsets.release(), - rmm::device_buffer{}, 0); + auto offsets_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); struct_children.push_back(cudf::make_strings_column( num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask))); } else { @@ -3455,60 +4105,93 @@ std::unique_ptr build_nested_struct_column( struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); } } else { - bool has_def_str = has_def; + bool has_def_str = has_def; auto const& def_str = default_strings[child_schema_idx]; - int32_t def_len = has_def_str ? static_cast(def_str.size()) : 0; + int32_t def_len = has_def_str ? 
static_cast(def_str.size()) : 0; rmm::device_uvector d_default_str(def_len, stream, mr); if (has_def_str && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), def_str.data(), def_len, - cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), + def_str.data(), + def_len, + cudaMemcpyHostToDevice, + stream.value())); } rmm::device_uvector lengths(num_rows, stream, mr); extract_nested_lengths_kernel<<>>( - d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields, - lengths.data(), num_rows, has_def_str, def_len); + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + lengths.data(), + num_rows, + has_def_str, + def_len); rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), - output_offsets.begin(), 0); + thrust::exclusive_scan( + rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, output_offsets.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, + output_offsets.data() + num_rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + lengths.data() + num_rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); stream.synchronize(); total_chars += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_chars, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, + &total_chars, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector 
chars(total_chars, stream, mr); if (total_chars > 0) { copy_nested_varlen_data_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, output_offsets.data(), - chars.data(), num_rows, has_def_str, d_default_str.data(), def_len); + message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + output_offsets.data(), + chars.data(), + num_rows, + has_def_str, + d_default_str.data(), + def_len); } rmm::device_uvector valid(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - valid.begin(), - [plocs = d_parent_locs.data(), - flocs = d_child_locations.data(), - ci, num_child_fields, has_def_str] __device__(auto row) { - return (plocs[row].offset >= 0 && - flocs[row * num_child_fields + ci].offset >= 0) || has_def_str; - }); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.begin(), + [plocs = d_parent_locs.data(), + flocs = d_child_locations.data(), + ci, + num_child_fields, + has_def_str] __device__(auto row) { + return (plocs[row].offset >= 0 && flocs[row * num_child_fields + ci].offset >= 0) || + has_def_str; + }); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, output_offsets.release(), - rmm::device_buffer{}, 0); + num_rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); struct_children.push_back(cudf::make_strings_column( num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask))); } @@ -3516,65 +4199,106 @@ std::unique_ptr build_nested_struct_column( } case cudf::type_id::LIST: { // bytes (BinaryType) represented as LIST - bool has_def_bytes = has_def; + bool has_def_bytes = has_def; auto const& 
def_bytes = default_strings[child_schema_idx]; - int32_t def_len = has_def_bytes ? static_cast(def_bytes.size()) : 0; + int32_t def_len = has_def_bytes ? static_cast(def_bytes.size()) : 0; rmm::device_uvector d_default_bytes(def_len, stream, mr); if (has_def_bytes && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_bytes.data(), def_bytes.data(), def_len, - cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_bytes.data(), + def_bytes.data(), + def_len, + cudaMemcpyHostToDevice, + stream.value())); } rmm::device_uvector lengths(num_rows, stream, mr); extract_nested_lengths_kernel<<>>( - d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields, - lengths.data(), num_rows, has_def_bytes, def_len); + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + lengths.data(), + num_rows, + has_def_bytes, + def_len); rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), - output_offsets.begin(), 0); + thrust::exclusive_scan( + rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); int32_t total_bytes = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, output_offsets.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, + output_offsets.data() + num_rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + lengths.data() + num_rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); stream.synchronize(); total_bytes += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_bytes, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + 
CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, + &total_bytes, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector bytes_data(total_bytes, stream, mr); if (total_bytes > 0) { copy_nested_varlen_data_kernel<<>>( - message_data, list_offsets, base_offset, d_parent_locs.data(), - d_child_locations.data(), ci, num_child_fields, output_offsets.data(), - bytes_data.data(), num_rows, has_def_bytes, d_default_bytes.data(), def_len); + message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + output_offsets.data(), + bytes_data.data(), + num_rows, + has_def_bytes, + d_default_bytes.data(), + def_len); } rmm::device_uvector valid(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - valid.begin(), - [plocs = d_parent_locs.data(), - flocs = d_child_locations.data(), - ci, num_child_fields, has_def_bytes] __device__(auto row) { - return (plocs[row].offset >= 0 && - flocs[row * num_child_fields + ci].offset >= 0) || has_def_bytes; - }); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.begin(), + [plocs = d_parent_locs.data(), + flocs = d_child_locations.data(), + ci, + num_child_fields, + has_def_bytes] __device__(auto row) { + return (plocs[row].offset >= 0 && flocs[row * num_child_fields + ci].offset >= 0) || + has_def_bytes; + }); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, output_offsets.release(), - rmm::device_buffer{}, 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); auto bytes_child = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, 
total_bytes, - rmm::device_buffer(bytes_data.data(), total_bytes, stream, mr), rmm::device_buffer{}, 0); - struct_children.push_back(cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(bytes_child), null_count, std::move(mask), stream, mr)); + cudf::data_type{cudf::type_id::UINT8}, + total_bytes, + rmm::device_buffer(bytes_data.data(), total_bytes, stream, mr), + rmm::device_buffer{}, + 0); + struct_children.push_back(cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(bytes_child), + null_count, + std::move(mask), + stream, + mr)); break; } case cudf::type_id::STRUCT: { @@ -3585,29 +4309,45 @@ std::unique_ptr build_nested_struct_column( } rmm::device_uvector d_gc_parent(num_rows, stream, mr); compute_grandchild_parent_locations_kernel<<>>( - d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields, - d_gc_parent.data(), num_rows); - struct_children.push_back(build_nested_struct_column( - message_data, list_offsets, base_offset, d_gc_parent, gc_indices, - schema, num_fields, schema_output_types, default_ints, default_floats, default_bools, - default_strings, enum_valid_values, enum_names, d_row_has_invalid_enum, d_error, - num_rows, stream, mr, depth + 1)); + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + d_gc_parent.data(), + num_rows); + struct_children.push_back(build_nested_struct_column(message_data, + list_offsets, + base_offset, + d_gc_parent, + gc_indices, + schema, + num_fields, + schema_output_types, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error, + num_rows, + stream, + mr, + depth + 1)); break; } - default: - struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); - break; + default: struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); break; } } rmm::device_uvector struct_valid(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - 
thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - struct_valid.begin(), - [plocs = d_parent_locs.data()] __device__(auto row) { - return plocs[row].offset >= 0; - }); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + struct_valid.begin(), + [plocs = d_parent_locs.data()] __device__(auto row) { return plocs[row].offset >= 0; }); auto [struct_mask, struct_null_count] = make_null_mask_from_valid(struct_valid, stream, mr); return cudf::make_structs_column( num_rows, std::move(struct_children), struct_null_count, std::move(struct_mask), stream, mr); @@ -3635,9 +4375,9 @@ std::unique_ptr decode_protobuf_to_struct( "binary_input must be a LIST column"); auto const stream = cudf::get_default_stream(); - auto mr = cudf::get_current_device_resource_ref(); - auto num_rows = binary_input.size(); - auto num_fields = static_cast(schema.size()); + auto mr = cudf::get_current_device_resource_ref(); + auto num_rows = binary_input.size(); + auto num_fields = static_cast(schema.size()); if (num_rows == 0 || num_fields == 0) { // Build empty struct based on top-level fields with proper nested structure @@ -3649,13 +4389,19 @@ std::unique_ptr decode_protobuf_to_struct( // Repeated message field - build empty LIST with proper struct element rmm::device_uvector offsets(1, stream, mr); int32_t zero = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(offsets.data(), &zero, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync( + offsets.data(), &zero, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); auto offsets_col = std::make_unique( cudf::data_type{cudf::type_id::INT32}, 1, offsets.release(), rmm::device_buffer{}, 0); auto empty_struct = make_empty_struct_column_with_schema( schema, schema_output_types, i, num_fields, stream, mr); - empty_children.push_back(cudf::make_lists_column(0, std::move(offsets_col), std::move(empty_struct), - 0, 
rmm::device_buffer{}, stream, mr)); + empty_children.push_back(cudf::make_lists_column(0, + std::move(offsets_col), + std::move(empty_struct), + 0, + rmm::device_buffer{}, + stream, + mr)); } else if (field_type.id() == cudf::type_id::STRUCT && !schema[i].is_repeated) { // Non-repeated nested message field empty_children.push_back(make_empty_struct_column_with_schema( @@ -3665,29 +4411,30 @@ std::unique_ptr decode_protobuf_to_struct( } } } - return cudf::make_structs_column(0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); + return cudf::make_structs_column( + 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); } // Copy schema to device std::vector h_device_schema(num_fields); for (int i = 0; i < num_fields; i++) { - h_device_schema[i] = { - schema[i].field_number, - schema[i].parent_idx, - schema[i].depth, - schema[i].wire_type, - static_cast(schema[i].output_type), - schema[i].encoding, - schema[i].is_repeated, - schema[i].is_required, - schema[i].has_default_value - }; + h_device_schema[i] = {schema[i].field_number, + schema[i].parent_idx, + schema[i].depth, + schema[i].wire_type, + static_cast(schema[i].output_type), + schema[i].encoding, + schema[i].is_repeated, + schema[i].is_required, + schema[i].has_default_value}; } rmm::device_uvector d_schema(num_fields, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_schema.data(), h_device_schema.data(), + CUDF_CUDA_TRY(cudaMemcpyAsync(d_schema.data(), + h_device_schema.data(), num_fields * sizeof(device_nested_field_descriptor), - cudaMemcpyHostToDevice, stream.value())); + cudaMemcpyHostToDevice, + stream.value())); auto d_in = cudf::column_device_view::create(binary_input, stream); @@ -3709,8 +4456,8 @@ std::unique_ptr decode_protobuf_to_struct( } int num_repeated = static_cast(repeated_field_indices.size()); - int num_nested = static_cast(nested_field_indices.size()); - int num_scalar = static_cast(scalar_field_indices.size()); + int num_nested = 
static_cast(nested_field_indices.size()); + int num_scalar = static_cast(scalar_field_indices.size()); // Error flag rmm::device_uvector d_error(1, stream, mr); @@ -3721,12 +4468,12 @@ std::unique_ptr decode_protobuf_to_struct( enum_valid_values.begin(), enum_valid_values.end(), [](auto const& v) { return !v.empty(); }); rmm::device_uvector d_row_has_invalid_enum(has_enum_fields ? num_rows : 0, stream, mr); if (has_enum_fields) { - CUDF_CUDA_TRY(cudaMemsetAsync(d_row_has_invalid_enum.data(), 0, - num_rows * sizeof(bool), stream.value())); + CUDF_CUDA_TRY( + cudaMemsetAsync(d_row_has_invalid_enum.data(), 0, num_rows * sizeof(bool), stream.value())); } auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((num_rows + threads - 1) / threads); + auto const blocks = static_cast((num_rows + threads - 1) / threads); // Allocate for counting repeated fields rmm::device_uvector d_repeated_info( @@ -3738,47 +4485,54 @@ std::unique_ptr decode_protobuf_to_struct( rmm::device_uvector d_nested_indices(num_nested > 0 ? 
num_nested : 1, stream, mr); if (num_repeated > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_repeated_indices.data(), repeated_field_indices.data(), - num_repeated * sizeof(int), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_repeated_indices.data(), + repeated_field_indices.data(), + num_repeated * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); } if (num_nested > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_nested_indices.data(), nested_field_indices.data(), - num_nested * sizeof(int), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_nested_indices.data(), + nested_field_indices.data(), + num_nested * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); } // Count repeated fields at depth 0 if (num_repeated > 0 || num_nested > 0) { - count_repeated_fields_kernel<<>>( - *d_in, - d_schema.data(), - num_fields, - 0, // depth_level - d_repeated_info.data(), - num_repeated, - d_repeated_indices.data(), - d_nested_locations.data(), - num_nested, - d_nested_indices.data(), - d_error.data()); + count_repeated_fields_kernel<<>>(*d_in, + d_schema.data(), + num_fields, + 0, // depth_level + d_repeated_info.data(), + num_repeated, + d_repeated_indices.data(), + d_nested_locations.data(), + num_nested, + d_nested_indices.data(), + d_error.data()); } // For scalar fields at depth 0, use the existing scan_all_fields_kernel // Use a map to store columns by schema index, then assemble in order at the end std::map> column_map; - + // Process scalar fields using existing infrastructure if (num_scalar > 0) { std::vector h_field_descs(num_scalar); for (int i = 0; i < num_scalar; i++) { - int schema_idx = scalar_field_indices[i]; - h_field_descs[i].field_number = schema[schema_idx].field_number; + int schema_idx = scalar_field_indices[i]; + h_field_descs[i].field_number = schema[schema_idx].field_number; h_field_descs[i].expected_wire_type = schema[schema_idx].wire_type; } rmm::device_uvector d_field_descs(num_scalar, 
stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_descs.data(), h_field_descs.data(), + CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_descs.data(), + h_field_descs.data(), num_scalar * sizeof(field_descriptor), - cudaMemcpyHostToDevice, stream.value())); + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector d_locations( static_cast(num_rows) * num_scalar, stream, mr); @@ -3791,7 +4545,10 @@ std::unique_ptr decode_protobuf_to_struct( bool has_required = false; for (int i = 0; i < num_scalar; i++) { int si = scalar_field_indices[i]; - if (schema[si].is_required) { has_required = true; break; } + if (schema[si].is_required) { + has_required = true; + break; + } } if (has_required) { rmm::device_uvector d_is_required(num_scalar, stream, mr); @@ -3799,8 +4556,11 @@ std::unique_ptr decode_protobuf_to_struct( for (int i = 0; i < num_scalar; i++) { h_is_required[i] = schema[scalar_field_indices[i]].is_required ? 1 : 0; } - CUDF_CUDA_TRY(cudaMemcpyAsync(d_is_required.data(), h_is_required.data(), - num_scalar * sizeof(uint8_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_is_required.data(), + h_is_required.data(), + num_scalar * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream.value())); check_required_fields_kernel<<>>( d_locations.data(), d_is_required.data(), num_scalar, num_rows, d_error.data()); } @@ -3808,30 +4568,41 @@ std::unique_ptr decode_protobuf_to_struct( // Extract scalar values (reusing existing extraction logic) cudf::lists_column_view const in_list_view(binary_input); - auto const* message_data = reinterpret_cast(in_list_view.child().data()); + auto const* message_data = + reinterpret_cast(in_list_view.child().data()); auto const* list_offsets = in_list_view.offsets().data(); cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&base_offset, list_offsets, sizeof(cudf::size_type), - cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync( + &base_offset, list_offsets, 
sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); for (int i = 0; i < num_scalar; i++) { int schema_idx = scalar_field_indices[i]; - auto const dt = schema_output_types[schema_idx]; + auto const dt = schema_output_types[schema_idx]; auto const enc = schema[schema_idx].encoding; - bool has_def = schema[schema_idx].has_default_value; + bool has_def = schema[schema_idx].has_default_value; switch (dt.id()) { case cudf::type_id::BOOL8: { rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid(num_rows, stream, mr); int64_t def_val = has_def ? (default_bools[schema_idx] ? 1 : 0) : 0; - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, def_val); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_val); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( + column_map[schema_idx] = std::make_unique( dt, num_rows, out.release(), std::move(mask), null_count); break; } @@ -3840,32 +4611,69 @@ std::unique_ptr decode_protobuf_to_struct( rmm::device_uvector valid(num_rows, stream, mr); int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, static_cast(def_int)); + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + static_cast(def_int)); } else { - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } // Enum validation: check if this INT32 field has valid enum values if (schema_idx < static_cast(enum_valid_values.size())) { auto const& valid_enums = enum_valid_values[schema_idx]; if (!valid_enums.empty()) { rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), valid_enums.data(), - valid_enums.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); 
validate_enum_values_kernel<<>>( - out.data(), valid.data(), d_row_has_invalid_enum.data(), - d_valid_enums.data(), static_cast(valid_enums.size()), num_rows); + out.data(), + valid.data(), + d_row_has_invalid_enum.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + num_rows); } } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( + column_map[schema_idx] = std::make_unique( dt, num_rows, out.release(), std::move(mask), null_count); break; } @@ -3874,16 +4682,36 @@ std::unique_ptr decode_protobuf_to_struct( rmm::device_uvector valid(num_rows, stream, mr); int64_t def_int = has_def ? default_ints[schema_idx] : 0; if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, static_cast(def_int)); + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + static_cast(def_int)); } else { - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( + column_map[schema_idx] = std::make_unique( dt, num_rows, out.release(), std::move(mask), null_count); break; } @@ -3892,20 +4720,50 @@ std::unique_ptr decode_protobuf_to_struct( rmm::device_uvector valid(num_rows, stream, mr); int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } else { - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( + column_map[schema_idx] = std::make_unique( dt, num_rows, out.release(), std::move(mask), null_count); break; } @@ -3914,16 +4772,36 @@ std::unique_ptr decode_protobuf_to_struct( rmm::device_uvector valid(num_rows, stream, mr); int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, static_cast(def_int)); + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + static_cast(def_int)); } else { - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( + column_map[schema_idx] = std::make_unique( dt, num_rows, out.release(), std::move(mask), null_count); break; } @@ -3931,11 +4809,21 @@ std::unique_ptr decode_protobuf_to_struct( rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid(num_rows, stream, mr); float def_float = has_def ? 
static_cast(default_floats[schema_idx]) : 0.0f; - extract_fixed_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, def_float); + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_float); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( + column_map[schema_idx] = std::make_unique( dt, num_rows, out.release(), std::move(mask), null_count); break; } @@ -3943,11 +4831,21 @@ std::unique_ptr decode_protobuf_to_struct( rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid(num_rows, stream, mr); double def_double = has_def ? default_floats[schema_idx] : 0.0; - extract_fixed_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, def_double); + extract_fixed_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_double); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( + column_map[schema_idx] = std::make_unique( dt, num_rows, out.release(), std::move(mask), null_count); break; } @@ -3960,22 +4858,39 @@ std::unique_ptr decode_protobuf_to_struct( rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid(num_rows, stream, mr); int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; - extract_varint_from_locations_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - out.data(), valid.data(), num_rows, d_error.data(), has_def, def_int); + extract_varint_from_locations_kernel + <<>>(message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + out.data(), + valid.data(), + num_rows, + d_error.data(), + has_def, + def_int); if (schema_idx < static_cast(enum_valid_values.size()) && schema_idx < static_cast(enum_names.size())) { - auto const& valid_enums = enum_valid_values[schema_idx]; + auto const& valid_enums = enum_valid_values[schema_idx]; auto const& enum_name_bytes = enum_names[schema_idx]; if (!valid_enums.empty() && valid_enums.size() == enum_name_bytes.size()) { // Validate enum numeric values first. rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), valid_enums.data(), - valid_enums.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); validate_enum_values_kernel<<>>( - out.data(), valid.data(), d_row_has_invalid_enum.data(), - d_valid_enums.data(), static_cast(valid_enums.size()), num_rows); + out.data(), + valid.data(), + d_row_has_invalid_enum.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + num_rows); // Build flattened enum-name chars and offsets on host, then copy to device. 
std::vector h_name_offsets(valid_enums.size() + 1, 0); @@ -3994,48 +4909,80 @@ std::unique_ptr decode_protobuf_to_struct( } rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), h_name_offsets.data(), - h_name_offsets.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), + h_name_offsets.data(), + h_name_offsets.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector d_name_chars(total_name_chars, stream, mr); if (total_name_chars > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), h_name_chars.data(), - total_name_chars * sizeof(uint8_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), + h_name_chars.data(), + total_name_chars * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream.value())); } // Compute output UTF-8 lengths rmm::device_uvector lengths(num_rows, stream, mr); compute_enum_string_lengths_kernel<<>>( - out.data(), valid.data(), d_valid_enums.data(), d_name_offsets.data(), - static_cast(valid_enums.size()), lengths.data(), num_rows); + out.data(), + valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + static_cast(valid_enums.size()), + lengths.data(), + num_rows); // Prefix sum for string offsets rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), - output_offsets.begin(), 0); + thrust::exclusive_scan(rmm::exec_policy(stream), + lengths.begin(), + lengths.end(), + output_offsets.begin(), + 0); int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, output_offsets.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, + output_offsets.data() + num_rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); int32_t 
last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + lengths.data() + num_rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); stream.synchronize(); total_chars += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_chars, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, + &total_chars, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { copy_enum_string_chars_kernel<<>>( - out.data(), valid.data(), d_valid_enums.data(), d_name_offsets.data(), - d_name_chars.data(), static_cast(valid_enums.size()), - output_offsets.data(), chars.data(), num_rows); + out.data(), + valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + d_name_chars.data(), + static_cast(valid_enums.size()), + output_offsets.data(), + chars.data(), + num_rows); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, output_offsets.release(), - rmm::device_buffer{}, 0); + auto offsets_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); column_map[schema_idx] = cudf::make_strings_column( num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); } else { @@ -4049,14 +4996,17 @@ std::unique_ptr decode_protobuf_to_struct( } } else { // Regular protobuf STRING (length-delimited) - bool has_def_str = has_def; + bool has_def_str = has_def; auto const& def_str = default_strings[schema_idx]; - int32_t def_len = has_def_str ? static_cast(def_str.size()) : 0; + int32_t def_len = has_def_str ? 
static_cast(def_str.size()) : 0; rmm::device_uvector d_default_str(def_len, stream, mr); if (has_def_str && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), def_str.data(), def_len, - cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), + def_str.data(), + def_len, + cudaMemcpyHostToDevice, + stream.value())); } // Extract string lengths @@ -4066,43 +5016,64 @@ std::unique_ptr decode_protobuf_to_struct( // Compute offsets via prefix sum rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), - output_offsets.begin(), 0); + thrust::exclusive_scan( + rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, output_offsets.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, + output_offsets.data() + num_rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + lengths.data() + num_rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); stream.synchronize(); total_chars += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_chars, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, + &total_chars, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); // Copy string data rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { copy_scalar_string_data_kernel<<>>( - message_data, list_offsets, base_offset, d_locations.data(), i, num_scalar, - output_offsets.data(), chars.data(), 
num_rows, has_def_str, - d_default_str.data(), def_len); + message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + output_offsets.data(), + chars.data(), + num_rows, + has_def_str, + d_default_str.data(), + def_len); } // Build validity mask rmm::device_uvector valid(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - valid.begin(), - [locs = d_locations.data(), i, num_scalar, has_def_str] __device__(auto row) { - return locs[row * num_scalar + i].offset >= 0 || has_def_str; - }); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.begin(), + [locs = d_locations.data(), i, num_scalar, has_def_str] __device__(auto row) { + return locs[row * num_scalar + i].offset >= 0 || has_def_str; + }); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, output_offsets.release(), - rmm::device_buffer{}, 0); + num_rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); column_map[schema_idx] = cudf::make_strings_column( num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); } @@ -4110,14 +5081,17 @@ std::unique_ptr decode_protobuf_to_struct( } case cudf::type_id::LIST: { // bytes (BinaryType) represented as LIST - bool has_def_bytes = has_def; + bool has_def_bytes = has_def; auto const& def_bytes = default_strings[schema_idx]; - int32_t def_len = has_def_bytes ? static_cast(def_bytes.size()) : 0; + int32_t def_len = has_def_bytes ? 
static_cast(def_bytes.size()) : 0; rmm::device_uvector d_default_bytes(def_len, stream, mr); if (has_def_bytes && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_bytes.data(), def_bytes.data(), def_len, - cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_bytes.data(), + def_bytes.data(), + def_len, + cudaMemcpyHostToDevice, + stream.value())); } rmm::device_uvector lengths(num_rows, stream, mr); @@ -4125,46 +5099,75 @@ std::unique_ptr decode_protobuf_to_struct( d_locations.data(), i, num_scalar, lengths.data(), num_rows, has_def_bytes, def_len); rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), lengths.begin(), lengths.end(), - output_offsets.begin(), 0); + thrust::exclusive_scan( + rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); int32_t total_bytes = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, output_offsets.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, + output_offsets.data() + num_rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, lengths.data() + num_rows - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + lengths.data() + num_rows - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); stream.synchronize(); total_bytes += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, &total_bytes, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, + &total_bytes, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector bytes_data(total_bytes, stream, mr); if (total_bytes > 0) { copy_scalar_string_data_kernel<<>>( - message_data, list_offsets, base_offset, 
d_locations.data(), i, num_scalar, - output_offsets.data(), bytes_data.data(), num_rows, has_def_bytes, - d_default_bytes.data(), def_len); + message_data, + list_offsets, + base_offset, + d_locations.data(), + i, + num_scalar, + output_offsets.data(), + bytes_data.data(), + num_rows, + has_def_bytes, + d_default_bytes.data(), + def_len); } rmm::device_uvector valid(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - valid.begin(), - [locs = d_locations.data(), i, num_scalar, has_def_bytes] __device__(auto row) { - return locs[row * num_scalar + i].offset >= 0 || has_def_bytes; - }); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.begin(), + [locs = d_locations.data(), i, num_scalar, has_def_bytes] __device__(auto row) { + return locs[row * num_scalar + i].offset >= 0 || has_def_bytes; + }); auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, output_offsets.release(), - rmm::device_buffer{}, 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); auto bytes_child = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, total_bytes, - rmm::device_buffer(bytes_data.data(), total_bytes, stream, mr), rmm::device_buffer{}, 0); - column_map[schema_idx] = cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(bytes_child), null_count, std::move(mask), stream, mr); + cudf::data_type{cudf::type_id::UINT8}, + total_bytes, + rmm::device_buffer(bytes_data.data(), total_bytes, stream, mr), + rmm::device_buffer{}, + 0); + column_map[schema_idx] = cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(bytes_child), + null_count, + std::move(mask), 
+ stream, + mr); break; } default: @@ -4178,16 +5181,18 @@ std::unique_ptr decode_protobuf_to_struct( // Process repeated fields if (num_repeated > 0) { std::vector h_repeated_info(static_cast(num_rows) * num_repeated); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_repeated_info.data(), d_repeated_info.data(), + CUDF_CUDA_TRY(cudaMemcpyAsync(h_repeated_info.data(), + d_repeated_info.data(), h_repeated_info.size() * sizeof(repeated_field_info), - cudaMemcpyDeviceToHost, stream.value())); + cudaMemcpyDeviceToHost, + stream.value())); stream.synchronize(); cudf::lists_column_view const in_list_view(binary_input); auto const* list_offsets = in_list_view.offsets().data(); for (int ri = 0; ri < num_repeated; ri++) { - int schema_idx = repeated_field_indices[ri]; + int schema_idx = repeated_field_indices[ri]; auto element_type = schema_output_types[schema_idx]; // Get per-row counts for this repeated field entirely on GPU (performance fix!) @@ -4197,9 +5202,9 @@ std::unique_ptr decode_protobuf_to_struct( thrust::make_counting_iterator(num_rows), d_field_counts.begin(), extract_strided_count{d_repeated_info.data(), ri, num_repeated}); - - int total_count = thrust::reduce(rmm::exec_policy(stream), - d_field_counts.begin(), d_field_counts.end(), 0); + + int total_count = + thrust::reduce(rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), 0); // Still need host-side field_info for build_repeated_scalar_column std::vector field_info(num_rows); @@ -4207,54 +5212,93 @@ std::unique_ptr decode_protobuf_to_struct( field_info[row] = h_repeated_info[row * num_repeated + ri]; } - if (total_count > 0) { // Build offsets for occurrence scanning on GPU (performance fix!) 
rmm::device_uvector d_occ_offsets(num_rows + 1, stream, mr); thrust::exclusive_scan(rmm::exec_policy(stream), - d_field_counts.begin(), d_field_counts.end(), - d_occ_offsets.begin(), 0); + d_field_counts.begin(), + d_field_counts.end(), + d_occ_offsets.begin(), + 0); // Set last element - CUDF_CUDA_TRY(cudaMemcpyAsync(d_occ_offsets.data() + num_rows, &total_count, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_occ_offsets.data() + num_rows, + &total_count, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); // Scan for all occurrences rmm::device_uvector d_occurrences(total_count, stream, mr); scan_repeated_field_occurrences_kernel<<>>( - *d_in, d_schema.data(), schema_idx, 0, d_occ_offsets.data(), - d_occurrences.data(), d_error.data()); + *d_in, + d_schema.data(), + schema_idx, + 0, + d_occ_offsets.data(), + d_occurrences.data(), + d_error.data()); // Build the appropriate column type based on element type // For now, support scalar repeated fields auto child_type_id = static_cast(h_device_schema[schema_idx].output_type_id); - + // The output_type in schema is the LIST type, but we need element type // For repeated int32, output_type should indicate the element is INT32 switch (child_type_id) { case cudf::type_id::INT32: - column_map[schema_idx] = build_repeated_scalar_column( - binary_input, h_device_schema[schema_idx], field_info, d_occurrences, - total_count, num_rows, stream, mr); + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + h_device_schema[schema_idx], + field_info, + d_occurrences, + total_count, + num_rows, + stream, + mr); break; case cudf::type_id::INT64: - column_map[schema_idx] = build_repeated_scalar_column( - binary_input, h_device_schema[schema_idx], field_info, d_occurrences, - total_count, num_rows, stream, mr); + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + h_device_schema[schema_idx], + field_info, + d_occurrences, + total_count, 
+ num_rows, + stream, + mr); break; case cudf::type_id::FLOAT32: - column_map[schema_idx] = build_repeated_scalar_column( - binary_input, h_device_schema[schema_idx], field_info, d_occurrences, - total_count, num_rows, stream, mr); + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + h_device_schema[schema_idx], + field_info, + d_occurrences, + total_count, + num_rows, + stream, + mr); break; case cudf::type_id::FLOAT64: - column_map[schema_idx] = build_repeated_scalar_column( - binary_input, h_device_schema[schema_idx], field_info, d_occurrences, - total_count, num_rows, stream, mr); + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + h_device_schema[schema_idx], + field_info, + d_occurrences, + total_count, + num_rows, + stream, + mr); break; case cudf::type_id::BOOL8: - column_map[schema_idx] = build_repeated_scalar_column( - binary_input, h_device_schema[schema_idx], field_info, d_occurrences, - total_count, num_rows, stream, mr); + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + h_device_schema[schema_idx], + field_info, + d_occurrences, + total_count, + num_rows, + stream, + mr); break; case cudf::type_id::STRING: { auto enc = schema[schema_idx].encoding; @@ -4265,28 +5309,37 @@ std::unique_ptr decode_protobuf_to_struct( enum_valid_values[schema_idx].size() == enum_names[schema_idx].size()) { // Repeated enum-as-string: extract varints, then convert to strings. 
auto const& valid_enums = enum_valid_values[schema_idx]; - auto const& name_bytes = enum_names[schema_idx]; + auto const& name_bytes = enum_names[schema_idx]; cudf::lists_column_view const in_lv(binary_input); auto const* msg_data = reinterpret_cast(in_lv.child().data()); - auto const* loffs = in_lv.offsets().data(); + auto const* loffs = in_lv.offsets().data(); cudf::size_type boff = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&boff, loffs, sizeof(cudf::size_type), - cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync( + &boff, loffs, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); // 1. Extract enum integer values from occurrences rmm::device_uvector enum_ints(total_count, stream, mr); - auto const rep_blocks = static_cast((total_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK); - extract_repeated_varint_kernel<<>>( - msg_data, loffs, boff, d_occurrences.data(), total_count, - enum_ints.data(), d_error.data()); + auto const rep_blocks = + static_cast((total_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK); + extract_repeated_varint_kernel + <<>>(msg_data, + loffs, + boff, + d_occurrences.data(), + total_count, + enum_ints.data(), + d_error.data()); // 2. 
Build device-side enum lookup tables rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), valid_enums.data(), - valid_enums.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); std::vector h_name_offsets(valid_enums.size() + 1, 0); int32_t total_name_chars = 0; @@ -4303,12 +5356,18 @@ std::unique_ptr decode_protobuf_to_struct( } } rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), h_name_offsets.data(), - h_name_offsets.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), + h_name_offsets.data(), + h_name_offsets.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); rmm::device_uvector d_name_chars(total_name_chars, stream, mr); if (total_name_chars > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), h_name_chars.data(), - total_name_chars * sizeof(uint8_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), + h_name_chars.data(), + total_name_chars * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream.value())); } // 3. Validate enum values (sets row_has_invalid_enum for PERMISSIVE mode). @@ -4320,78 +5379,135 @@ std::unique_ptr decode_protobuf_to_struct( // 4. 
Compute per-element string lengths rmm::device_uvector elem_lengths(total_count, stream, mr); - compute_enum_string_lengths_kernel<<>>( - enum_ints.data(), elem_valid.data(), d_valid_enums.data(), d_name_offsets.data(), - static_cast(valid_enums.size()), elem_lengths.data(), total_count); + compute_enum_string_lengths_kernel<<>>( + enum_ints.data(), + elem_valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + static_cast(valid_enums.size()), + elem_lengths.data(), + total_count); // 5. Build string offsets rmm::device_uvector str_offsets(total_count + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), elem_lengths.begin(), elem_lengths.end(), - str_offsets.begin(), 0); + thrust::exclusive_scan(rmm::exec_policy(stream), + elem_lengths.begin(), + elem_lengths.end(), + str_offsets.begin(), + 0); int32_t total_chars = 0; if (total_count > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, str_offsets.data() + total_count - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, + str_offsets.data() + total_count - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, elem_lengths.data() + total_count - 1, - sizeof(int32_t), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, + elem_lengths.data() + total_count - 1, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); stream.synchronize(); total_chars += last_len; } - CUDF_CUDA_TRY(cudaMemcpyAsync(str_offsets.data() + total_count, &total_chars, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(str_offsets.data() + total_count, + &total_chars, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); // 6. 
Copy string chars rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { copy_enum_string_chars_kernel<<>>( - enum_ints.data(), elem_valid.data(), d_valid_enums.data(), d_name_offsets.data(), - d_name_chars.data(), static_cast(valid_enums.size()), - str_offsets.data(), chars.data(), total_count); + enum_ints.data(), + elem_valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + d_name_chars.data(), + static_cast(valid_enums.size()), + str_offsets.data(), + chars.data(), + total_count); } // 7. Assemble LIST column - auto str_offs_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, total_count + 1, str_offsets.release(), - rmm::device_buffer{}, 0); + auto str_offs_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_count + 1, + str_offsets.release(), + rmm::device_buffer{}, + 0); auto child_col = cudf::make_strings_column( total_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); // Build list offsets from per-row counts rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan(rmm::exec_policy(stream), - d_field_counts.begin(), d_field_counts.end(), - list_offs.begin(), 0); - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_count, - sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - - auto list_offs_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), - rmm::device_buffer{}, 0); + d_field_counts.begin(), + d_field_counts.end(), + list_offs.begin(), + 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, + &total_count, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + auto list_offs_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); auto input_null_count = binary_input.null_count(); if (input_null_count > 0) { - auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - 
column_map[schema_idx] = cudf::make_lists_column( - num_rows, std::move(list_offs_col), std::move(child_col), - input_null_count, std::move(null_mask), stream, mr); + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + column_map[schema_idx] = cudf::make_lists_column(num_rows, + std::move(list_offs_col), + std::move(child_col), + input_null_count, + std::move(null_mask), + stream, + mr); } else { - column_map[schema_idx] = cudf::make_lists_column( - num_rows, std::move(list_offs_col), std::move(child_col), - 0, rmm::device_buffer{}, stream, mr); + column_map[schema_idx] = cudf::make_lists_column(num_rows, + std::move(list_offs_col), + std::move(child_col), + 0, + rmm::device_buffer{}, + stream, + mr); } } else { - column_map[schema_idx] = build_repeated_string_column( - binary_input, h_device_schema[schema_idx], field_info, d_occurrences, - total_count, num_rows, false, stream, mr); + column_map[schema_idx] = build_repeated_string_column(binary_input, + h_device_schema[schema_idx], + field_info, + d_occurrences, + total_count, + num_rows, + false, + stream, + mr); } break; } case cudf::type_id::LIST: // bytes as LIST - column_map[schema_idx] = build_repeated_string_column( - binary_input, h_device_schema[schema_idx], field_info, d_occurrences, - total_count, num_rows, true, stream, mr); + column_map[schema_idx] = build_repeated_string_column(binary_input, + h_device_schema[schema_idx], + field_info, + d_occurrences, + total_count, + num_rows, + true, + stream, + mr); break; case cudf::type_id::STRUCT: { // Repeated message field - ArrayType(StructType) @@ -4400,13 +5516,26 @@ std::unique_ptr decode_protobuf_to_struct( // No child fields - create null column column_map[schema_idx] = make_null_column(element_type, num_rows, stream, mr); } else { - column_map[schema_idx] = build_repeated_struct_column( - binary_input, h_device_schema[schema_idx], field_info, d_occurrences, - total_count, num_rows, h_device_schema, child_field_indices, - 
schema_output_types, default_ints, default_floats, default_bools, - default_strings, schema, enum_valid_values, enum_names, - d_row_has_invalid_enum, d_error, - stream, mr); + column_map[schema_idx] = build_repeated_struct_column(binary_input, + h_device_schema[schema_idx], + field_info, + d_occurrences, + total_count, + num_rows, + h_device_schema, + child_field_indices, + schema_output_types, + default_ints, + default_floats, + default_bools, + default_strings, + schema, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error, + stream, + mr); } break; } @@ -4419,9 +5548,12 @@ std::unique_ptr decode_protobuf_to_struct( // All rows have count=0 - create list of empty elements rmm::device_uvector offsets(num_rows + 1, stream, mr); thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), rmm::device_buffer{}, 0); - + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + // Build appropriate empty child column std::unique_ptr child_col; auto child_type_id = static_cast(h_device_schema[schema_idx].output_type_id); @@ -4432,15 +5564,25 @@ std::unique_ptr decode_protobuf_to_struct( } else { child_col = make_empty_column_safe(schema_output_types[schema_idx], stream, mr); } - + auto const input_null_count = binary_input.null_count(); if (input_null_count > 0) { - auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - column_map[schema_idx] = cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(child_col), input_null_count, std::move(null_mask), stream, mr); + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + column_map[schema_idx] = cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask), + stream, + mr); } else { - 
column_map[schema_idx] = cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + column_map[schema_idx] = cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + 0, + rmm::device_buffer{}, + stream, + mr); } } } @@ -4450,18 +5592,21 @@ std::unique_ptr decode_protobuf_to_struct( if (num_nested > 0) { // Copy nested locations to host for processing std::vector h_nested_locations(static_cast(num_rows) * num_nested); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_nested_locations.data(), d_nested_locations.data(), + CUDF_CUDA_TRY(cudaMemcpyAsync(h_nested_locations.data(), + d_nested_locations.data(), h_nested_locations.size() * sizeof(field_location), - cudaMemcpyDeviceToHost, stream.value())); + cudaMemcpyDeviceToHost, + stream.value())); stream.synchronize(); cudf::lists_column_view const in_list_view(binary_input); - auto const* message_data = reinterpret_cast(in_list_view.child().data()); + auto const* message_data = + reinterpret_cast(in_list_view.child().data()); auto const* list_offsets = in_list_view.offsets().data(); cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&base_offset, list_offsets, sizeof(cudf::size_type), - cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync( + &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); for (int ni = 0; ni < num_nested; ni++) { @@ -4472,8 +5617,8 @@ std::unique_ptr decode_protobuf_to_struct( if (child_field_indices.empty()) { // No child fields - create empty struct - column_map[parent_schema_idx] = make_null_column( - schema_output_types[parent_schema_idx], num_rows, stream, mr); + column_map[parent_schema_idx] = + make_null_column(schema_output_types[parent_schema_idx], num_rows, stream, mr); continue; } @@ -4483,15 +5628,31 @@ std::unique_ptr decode_protobuf_to_struct( for (int row = 0; row < num_rows; row++) { h_parent_locs[row] = 
h_nested_locations[row * num_nested + ni]; } - CUDF_CUDA_TRY(cudaMemcpyAsync(d_parent_locs.data(), h_parent_locs.data(), + CUDF_CUDA_TRY(cudaMemcpyAsync(d_parent_locs.data(), + h_parent_locs.data(), num_rows * sizeof(field_location), - cudaMemcpyHostToDevice, stream.value())); - column_map[parent_schema_idx] = build_nested_struct_column( - message_data, list_offsets, base_offset, d_parent_locs, - child_field_indices, schema, num_fields, schema_output_types, - default_ints, default_floats, default_bools, default_strings, - enum_valid_values, enum_names, d_row_has_invalid_enum, d_error, - num_rows, stream, mr, 0); + cudaMemcpyHostToDevice, + stream.value())); + column_map[parent_schema_idx] = build_nested_struct_column(message_data, + list_offsets, + base_offset, + d_parent_locs, + child_field_indices, + schema, + num_fields, + schema_output_types, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error, + num_rows, + stream, + mr, + 0); } } @@ -4504,7 +5665,8 @@ std::unique_ptr decode_protobuf_to_struct( top_level_children.push_back(std::move(it->second)); } else { // Field not processed - create null column - top_level_children.push_back(make_null_column(schema_output_types[i], num_rows, stream, mr)); + top_level_children.push_back( + make_null_column(schema_output_types[i], num_rows, stream, mr)); } } } @@ -4512,11 +5674,12 @@ std::unique_ptr decode_protobuf_to_struct( // Check for errors CUDF_CUDA_TRY(cudaPeekAtLastError()); int h_error = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); + CUDF_CUDA_TRY( + cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); if (fail_on_errors) { CUDF_EXPECTS(h_error == 0, - "Malformed protobuf message, unsupported wire type, or missing required field"); + "Malformed protobuf message, unsupported wire type, or 
missing required field"); } // Build final struct with PERMISSIVE mode null mask for invalid enums @@ -4530,8 +5693,9 @@ std::unique_ptr decode_protobuf_to_struct( [row_invalid = d_row_has_invalid_enum.data()] __device__(cudf::size_type row) { return !row_invalid[row]; }, - stream, mr); - struct_mask = std::move(mask); + stream, + mr); + struct_mask = std::move(mask); struct_null_count = null_count; } diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index 3fc3e7dc97..c8d94c3fc2 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -26,9 +26,9 @@ namespace spark_rapids_jni { // Encoding constants -constexpr int ENC_DEFAULT = 0; -constexpr int ENC_FIXED = 1; -constexpr int ENC_ZIGZAG = 2; +constexpr int ENC_DEFAULT = 0; +constexpr int ENC_FIXED = 1; +constexpr int ENC_ZIGZAG = 2; constexpr int ENC_ENUM_STRING = 3; // Maximum nesting depth for nested messages @@ -39,15 +39,15 @@ constexpr int MAX_NESTING_DEPTH = 10; * Used to represent flattened schema with parent-child relationships. 
*/ struct nested_field_descriptor { - int field_number; // Protobuf field number - int parent_idx; // Index of parent field in schema (-1 for top-level) - int depth; // Nesting depth (0 for top-level) - int wire_type; // Expected wire type + int field_number; // Protobuf field number + int parent_idx; // Index of parent field in schema (-1 for top-level) + int depth; // Nesting depth (0 for top-level) + int wire_type; // Expected wire type cudf::type_id output_type; // Output cudf type - int encoding; // Encoding type (ENC_DEFAULT, ENC_FIXED, ENC_ZIGZAG) - bool is_repeated; // Whether this field is repeated (array) - bool is_required; // Whether this field is required (proto2) - bool has_default_value; // Whether this field has a default value + int encoding; // Encoding type (ENC_DEFAULT, ENC_FIXED, ENC_ZIGZAG) + bool is_repeated; // Whether this field is repeated (array) + bool is_required; // Whether this field is required (proto2) + bool has_default_value; // Whether this field has a default value }; /** From 828e3c0524fe40e78538f4a450c9387ab91cd0bb Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 25 Feb 2026 16:20:27 +0800 Subject: [PATCH 025/107] clean up code Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 21 +- src/main/cpp/src/protobuf.cu | 1767 ++++++++++++++---------------- src/main/cpp/src/protobuf.hpp | 39 +- 3 files changed, 864 insertions(+), 963 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index d76d58f59e..de7ca90355 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -187,16 +187,17 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, } } - auto result = spark_rapids_jni::decode_protobuf_to_struct(*input, - schema, - schema_output_types, - default_int_values, - default_float_values, - default_bool_values, - default_string_values, - enum_values, - enum_name_values, - fail_on_errors); + spark_rapids_jni::ProtobufDecodeContext 
context{std::move(schema), + std::move(schema_output_types), + std::move(default_int_values), + std::move(default_float_values), + std::move(default_bool_values), + std::move(default_string_values), + std::move(enum_values), + std::move(enum_name_values), + static_cast(fail_on_errors)}; + + auto result = spark_rapids_jni::decode_protobuf_to_struct(*input, context); return cudf::jni::release_as_jlong(result); } diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index a678a9b332..a7d42e0967 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -666,6 +666,243 @@ __global__ void scan_repeated_field_occurrences_kernel( // Pass 2: Extract data kernels // ============================================================================ +// ============================================================================ +// Data Extraction Location Providers +// ============================================================================ + +struct TopLevelLocationProvider { + cudf::size_type const* offsets; + cudf::size_type base_offset; + field_location const* locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto loc = locations[thread_idx * num_fields + field_idx]; + if (loc.offset >= 0) { data_offset = offsets[thread_idx] - base_offset + loc.offset; } + return loc; + } +}; + +struct RepeatedLocationProvider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + repeated_occurrence const* occurrences; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto occ = occurrences[thread_idx]; + data_offset = row_offsets[occ.row_idx] - base_offset + occ.offset; + return {occ.offset, occ.length}; + } +}; + +struct NestedLocationProvider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* parent_locations; + field_location const* child_locations; + int 
field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto ploc = parent_locations[thread_idx]; + auto cloc = child_locations[thread_idx * num_fields + field_idx]; + if (ploc.offset >= 0 && cloc.offset >= 0) { + data_offset = row_offsets[thread_idx] - base_offset + ploc.offset + cloc.offset; + } else { + cloc.offset = -1; + } + return cloc; + } +}; + +struct NestedRepeatedLocationProvider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* parent_locations; + repeated_occurrence const* occurrences; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto occ = occurrences[thread_idx]; + auto ploc = parent_locations[occ.row_idx]; + data_offset = row_offsets[occ.row_idx] - base_offset + ploc.offset + occ.offset; + return {occ.offset, occ.length}; + } +}; + +struct RepeatedMsgChildLocationProvider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* msg_locations; + field_location const* child_locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto mloc = msg_locations[thread_idx]; + auto cloc = child_locations[thread_idx * num_fields + field_idx]; + if (mloc.offset >= 0 && cloc.offset >= 0) { + data_offset = row_offsets[thread_idx] - base_offset + mloc.offset + cloc.offset; + } else { + cloc.offset = -1; + } + return cloc; + } +}; + +template +__global__ void extract_varint_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + OutputType* out, + bool* valid, + int* error_flag, + bool has_default = false, + int64_t default_value = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + if (loc.offset < 0) { + if (has_default) { + 
out[idx] = static_cast(default_value); + if (valid) valid[idx] = true; + } else { + if (valid) valid[idx] = false; + } + return; + } + + uint8_t const* cur = message_data + data_offset; + uint8_t const* cur_end = cur + loc.length; + + uint64_t v; + int n; + if (!read_varint(cur, cur_end, v, n)) { + atomicExch(error_flag, ERR_VARINT); + if (valid) valid[idx] = false; + return; + } + + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } + out[idx] = static_cast(v); + if (valid) valid[idx] = true; +} + +template +__global__ void extract_fixed_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + OutputType* out, + bool* valid, + int* error_flag, + bool has_default = false, + OutputType default_value = OutputType{}) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + if (loc.offset < 0) { + if (has_default) { + out[idx] = default_value; + if (valid) valid[idx] = true; + } else { + if (valid) valid[idx] = false; + } + return; + } + + uint8_t const* cur = message_data + data_offset; + OutputType value; + + if constexpr (WT == WT_32BIT) { + if (loc.length < 4) { + atomicExch(error_flag, ERR_FIXED_LEN); + if (valid) valid[idx] = false; + return; + } + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { + if (loc.length < 8) { + atomicExch(error_flag, ERR_FIXED_LEN); + if (valid) valid[idx] = false; + return; + } + uint64_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } + + out[idx] = value; + if (valid) valid[idx] = true; +} + +template +__global__ void extract_lengths_kernel(LocationProvider loc_provider, + int total_items, + int32_t* out_lengths, + bool has_default = false, + int32_t default_length = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = 
loc_provider.get(idx, data_offset); + + if (loc.offset >= 0) { + out_lengths[idx] = loc.length; + } else if (has_default) { + out_lengths[idx] = default_length; + } else { + out_lengths[idx] = 0; + } +} +template +__global__ void copy_varlen_data_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + cudf::size_type const* output_offsets, + char* output_chars, + int* error_flag, + bool has_default = false, + uint8_t const* default_chars = nullptr, + int default_len = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + auto out_start = output_offsets[idx]; + + if (loc.offset < 0) { + if (has_default && default_len > 0) { + for (int i = 0; i < default_len; i++) { + output_chars[out_start + i] = static_cast(default_chars[i]); + } + } + return; + } + + uint8_t const* src = message_data + data_offset; + for (int i = 0; i < loc.length; i++) { + output_chars[out_start + i] = static_cast(src[i]); + } +} + /** * Extract varint field data using pre-recorded locations. * Supports default values for missing fields. @@ -854,44 +1091,6 @@ __global__ void extract_repeated_fixed_kernel(uint8_t const* message_data, out[idx] = value; } -/** - * Copy repeated variable-length data (string/bytes) using pre-recorded occurrences. 
- */ -__global__ void copy_repeated_varlen_data_kernel( - uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - repeated_occurrence const* occurrences, - int total_occurrences, - int32_t const* output_offsets, // Pre-computed output offsets for strings - char* output_data) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_occurrences) return; - - auto const& occ = occurrences[idx]; - if (occ.length == 0) return; - - auto row_start = row_offsets[occ.row_idx] - base_offset; - uint8_t const* src = message_data + row_start + occ.offset; - char* dst = output_data + output_offsets[idx]; - - memcpy(dst, src, occ.length); -} - -/** - * Extract lengths from repeated occurrences for prefix sum. - */ -__global__ void extract_repeated_lengths_kernel(repeated_occurrence const* occurrences, - int total_occurrences, - int32_t* lengths) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_occurrences) return; - - lengths[idx] = occurrences[idx].length; -} - // ============================================================================ // Nested message scanning kernels // ============================================================================ @@ -1230,107 +1429,6 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, } } -/** - * Extract varint values from repeated field occurrences within nested messages. 
- */ -template -__global__ void extract_repeated_in_nested_varint_kernel(uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - repeated_occurrence const* occurrences, - int total_count, - OutputType* out, - int* error_flag) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_count) return; - - auto const& occ = occurrences[idx]; - auto const& parent_loc = parent_locs[occ.row_idx]; - - cudf::size_type row_off = row_offsets[occ.row_idx] - base_offset; - uint8_t const* data_ptr = message_data + row_off + parent_loc.offset + occ.offset; - uint8_t const* msg_end = message_data + row_off + parent_loc.offset + parent_loc.length; - uint8_t const* varint_end = - (data_ptr + MAX_VARINT_BYTES < msg_end) ? (data_ptr + MAX_VARINT_BYTES) : msg_end; - - uint64_t val; - int vbytes; - if (!read_varint(data_ptr, varint_end, val, vbytes)) { - atomicExch(error_flag, ERR_VARINT); - return; - } - - if constexpr (ZigZag) { val = (val >> 1) ^ (-(val & 1)); } - - out[idx] = static_cast(val); -} - -template -__global__ void extract_repeated_in_nested_fixed_kernel(uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - repeated_occurrence const* occurrences, - int total_count, - OutputType* out, - int* error_flag) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_count) return; - - auto const& occ = occurrences[idx]; - auto const& parent_loc = parent_locs[occ.row_idx]; - - cudf::size_type row_off = row_offsets[occ.row_idx] - base_offset; - uint8_t const* data_ptr = message_data + row_off + parent_loc.offset + occ.offset; - - if constexpr (WT == WT_32BIT) { - if (occ.length < 4) { - atomicExch(error_flag, ERR_FIXED_LEN); - out[idx] = OutputType{}; - return; - } - uint32_t raw = load_le(data_ptr); - memcpy(&out[idx], &raw, sizeof(OutputType)); - } else { - if 
(occ.length < 8) { - atomicExch(error_flag, ERR_FIXED_LEN); - out[idx] = OutputType{}; - return; - } - uint64_t raw = load_le(data_ptr); - memcpy(&out[idx], &raw, sizeof(OutputType)); - } -} - -/** - * Extract string values from repeated field occurrences within nested messages. - */ -__global__ void extract_repeated_in_nested_string_kernel(uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - repeated_occurrence const* occurrences, - int total_count, - int32_t const* str_offsets, - char* chars, - int* error_flag) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_count) return; - - auto const& occ = occurrences[idx]; - auto const& parent_loc = parent_locs[occ.row_idx]; - - cudf::size_type row_off = row_offsets[occ.row_idx] - base_offset; - uint8_t const* data_ptr = message_data + row_off + parent_loc.offset + occ.offset; - - int32_t out_offset = str_offsets[idx]; - memcpy(chars + out_offset, data_ptr, occ.length); -} - /** * Extract varint child fields from repeated message occurrences. */ @@ -1439,54 +1537,6 @@ __global__ void extract_repeated_msg_child_fixed_kernel(uint8_t const* message_d * Kernel to extract string data from repeated message child fields. * Copies all strings in parallel on the GPU instead of per-string host copies. 
*/ -__global__ void extract_repeated_msg_child_strings_kernel( - uint8_t const* message_data, - int32_t const* msg_row_offsets, - field_location const* msg_locs, - field_location const* child_locs, - int child_idx, - int num_child_fields, - int32_t const* string_offsets, // Output offsets (exclusive scan of lengths) - char* output_chars, - bool* valid, - int total_count) -{ - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_count) return; - - auto const& field_loc = child_locs[idx * num_child_fields + child_idx]; - - if (field_loc.offset < 0 || field_loc.length == 0) { - valid[idx] = false; - return; - } - - valid[idx] = true; - - int32_t row_offset = msg_row_offsets[idx]; - int32_t msg_offset = msg_locs[idx].offset; - uint8_t const* str_src = message_data + row_offset + msg_offset + field_loc.offset; - char* str_dst = output_chars + string_offsets[idx]; - - // Copy string data - memcpy(str_dst, str_src, field_loc.length); -} - -/** - * Kernel to compute string lengths from child field locations. - */ -__global__ void compute_string_lengths_kernel(field_location const* child_locs, - int child_idx, - int num_child_fields, - int32_t* lengths, - int total_count) -{ - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_count) return; - - auto const& loc = child_locs[idx * num_child_fields + child_idx]; - lengths[idx] = (loc.offset >= 0) ? loc.length : 0; -} /** * Helper to build string column for repeated message child fields. 
@@ -1509,10 +1559,17 @@ inline std::unique_ptr build_repeated_msg_child_string_column( auto const threads = THREADS_PER_BLOCK; auto const blocks = (total_count + threads - 1) / threads; - // Compute string lengths on GPU + // Compute string lengths on GPU using child_locs directly rmm::device_uvector d_lengths(total_count, stream, mr); - compute_string_lengths_kernel<<>>( - d_child_locs.data(), child_idx, num_child_fields, d_lengths.data(), total_count); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + d_lengths.data(), + [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { + auto const& loc = child_locs[idx * ncf + ci]; + return loc.offset >= 0 ? loc.length : 0; + }); // Compute offsets via exclusive scan rmm::device_uvector d_str_offsets(total_count + 1, stream, mr); @@ -1544,30 +1601,33 @@ inline std::unique_ptr build_repeated_msg_child_string_column( // Allocate output chars and validity rmm::device_uvector d_chars(total_chars, stream, mr); - rmm::device_uvector d_valid(total_count, stream, mr); + rmm::device_uvector d_valid((total_count > 0 ? total_count : 1), stream, mr); - // Extract all strings in parallel on GPU (critical performance fix!) 
+ // Set validity for all entries + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + d_valid.data(), + [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { + return child_locs[idx * ncf + ci].offset >= 0; + }); + + // Extract all strings in parallel on GPU if (total_chars > 0) { - extract_repeated_msg_child_strings_kernel<<>>( - message_data, - d_msg_row_offsets.data(), - d_msg_locs.data(), - d_child_locs.data(), - child_idx, - num_child_fields, - d_str_offsets.data(), - d_chars.data(), - d_valid.data(), - total_count); - } else { - // No strings, just set validity - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_count), - d_valid.begin(), - [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__( - int idx) { return child_locs[idx * ncf + ci].offset >= 0; }); + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + child_idx, + num_child_fields}; + copy_varlen_data_kernel + <<>>(message_data, + loc_provider, + total_count, + d_str_offsets.data(), + d_chars.data(), + d_error.data()); } auto [mask, null_count] = make_null_mask_from_valid(d_valid, stream, mr); @@ -1616,12 +1676,19 @@ inline std::unique_ptr build_repeated_msg_child_bytes_column( auto const blocks = (total_count + threads - 1) / threads; rmm::device_uvector d_lengths(total_count, stream, mr); - compute_string_lengths_kernel<<>>( - d_child_locs.data(), child_idx, num_child_fields, d_lengths.data(), total_count); + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + d_lengths.data(), + [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { + auto const& loc = child_locs[idx * ncf + ci]; 
+ return loc.offset >= 0 ? loc.length : 0; + }); rmm::device_uvector d_offs(total_count + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), d_lengths.begin(), d_lengths.end(), d_offs.begin(), 0); + rmm::exec_policy(stream), d_lengths.begin(), d_lengths.end(), d_offs.data(), 0); int32_t total_bytes = 0; int32_t last_len = 0; @@ -1644,28 +1711,28 @@ inline std::unique_ptr build_repeated_msg_child_bytes_column( stream.value())); rmm::device_uvector d_bytes(total_bytes, stream, mr); - rmm::device_uvector d_valid(total_count, stream, mr); + rmm::device_uvector d_valid((total_count > 0 ? total_count : 1), stream, mr); + + // Set validity for all entries + thrust::transform( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + d_valid.data(), + [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { + return child_locs[idx * ncf + ci].offset >= 0; + }); if (total_bytes > 0) { - extract_repeated_msg_child_strings_kernel<<>>( - message_data, - d_msg_row_offsets.data(), - d_msg_locs.data(), - d_child_locs.data(), - child_idx, - num_child_fields, - d_offs.data(), - d_bytes.data(), - d_valid.data(), - total_count); - } else { - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_count), - d_valid.begin(), - [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__( - int idx) { return child_locs[idx * ncf + ci].offset >= 0; }); + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + child_idx, + num_child_fields}; + copy_varlen_data_kernel + <<>>( + message_data, loc_provider, total_count, d_offs.data(), d_bytes.data(), d_error.data()); } auto [mask, null_count] = make_null_mask_from_valid(d_valid, stream, mr); @@ -1918,131 +1985,11 @@ __global__ void extract_nested_fixed_kernel(uint8_t const* 
message_data, /** * Copy nested variable-length data (string/bytes). */ -__global__ void copy_nested_varlen_data_kernel(uint8_t const* message_data, - cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, - field_location const* parent_locations, - field_location const* field_locations, - int field_idx, - int num_fields, - int32_t const* output_offsets, - char* output_data, - int num_rows, - bool has_default = false, - uint8_t const* default_data = nullptr, - int32_t default_length = 0) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - auto const& parent_loc = parent_locations[row]; - auto const& field_loc = field_locations[row * num_fields + field_idx]; - - char* dst = output_data + output_offsets[row]; - - if (parent_loc.offset < 0 || field_loc.offset < 0) { - if (has_default && default_length > 0) { memcpy(dst, default_data, default_length); } - return; - } - - if (field_loc.length == 0) return; - - auto parent_row_start = parent_row_offsets[row] - parent_base_offset; - uint8_t const* src = message_data + parent_row_start + parent_loc.offset + field_loc.offset; - - memcpy(dst, src, field_loc.length); -} - -/** - * Extract nested field lengths for prefix sum. 
- */ -__global__ void extract_nested_lengths_kernel(field_location const* parent_locations, - field_location const* field_locations, - int field_idx, - int num_fields, - int32_t* lengths, - int num_rows, - bool has_default = false, - int32_t default_length = 0) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - auto const& parent_loc = parent_locations[row]; - auto const& field_loc = field_locations[row * num_fields + field_idx]; - - if (parent_loc.offset >= 0 && field_loc.offset >= 0) { - lengths[row] = field_loc.length; - } else if (has_default) { - lengths[row] = default_length; - } else { - lengths[row] = 0; - } -} - -/** - * Extract scalar string field lengths for prefix sum. - * For top-level STRING fields (not nested within a struct). - */ -__global__ void extract_scalar_string_lengths_kernel(field_location const* field_locations, - int field_idx, - int num_fields, - int32_t* lengths, - int num_rows, - bool has_default = false, - int32_t default_length = 0) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - auto const& loc = field_locations[row * num_fields + field_idx]; - - if (loc.offset >= 0) { - lengths[row] = loc.length; - } else if (has_default) { - lengths[row] = default_length; - } else { - lengths[row] = 0; - } -} /** * Copy scalar string field data. * For top-level STRING fields (not nested within a struct). 
*/ -__global__ void copy_scalar_string_data_kernel(uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type row_base_offset, - field_location const* field_locations, - int field_idx, - int num_fields, - int32_t const* output_offsets, - char* output_data, - int num_rows, - bool has_default = false, - uint8_t const* default_data = nullptr, - int32_t default_length = 0) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - auto const& loc = field_locations[row * num_fields + field_idx]; - - char* dst = output_data + output_offsets[row]; - - if (loc.offset < 0) { - // Field not found - use default if available - if (has_default && default_length > 0) { memcpy(dst, default_data, default_length); } - return; - } - - if (loc.length == 0) return; - - auto row_start = row_offsets[row] - row_base_offset; - uint8_t const* src = message_data + row_start + loc.offset; - - memcpy(dst, src, loc.length); -} // ============================================================================ // Utility functions @@ -2082,10 +2029,10 @@ std::unique_ptr make_null_column(cudf::data_type dtype, // Create empty strings column with all nulls rmm::device_uvector pairs(num_rows, stream, mr); thrust::fill(rmm::exec_policy(stream), - pairs.begin(), + pairs.data(), pairs.end(), cudf::strings::detail::string_index_pair{nullptr, 0}); - return cudf::strings::detail::make_strings_column(pairs.begin(), pairs.end(), stream, mr); + return cudf::strings::detail::make_strings_column(pairs.data(), pairs.end(), stream, mr); } case cudf::type_id::LIST: { // Create LIST with all nulls @@ -2472,14 +2419,14 @@ std::unique_ptr build_repeated_scalar_column( rmm::device_uvector counts(num_rows, stream, mr); thrust::transform(rmm::exec_policy(stream), - d_rep_info.begin(), + d_rep_info.data(), d_rep_info.end(), - counts.begin(), + counts.data(), [] __device__(repeated_field_info const& info) { return info.count; }); rmm::device_uvector 
list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); + rmm::exec_policy(stream), counts.data(), counts.end(), list_offs.begin(), 0); // Set last offset = total_count CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, @@ -2504,44 +2451,25 @@ std::unique_ptr build_repeated_scalar_column( constexpr bool is_floating_point = std::is_same_v || std::is_same_v; bool use_fixed_kernel = is_floating_point || (encoding == spark_rapids_jni::ENC_FIXED); + RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; if (use_fixed_kernel) { if constexpr (sizeof(T) == 4) { - extract_repeated_fixed_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_occurrences.data(), - total_count, - values.data(), - d_error.data()); + extract_fixed_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } else { - extract_repeated_fixed_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_occurrences.data(), - total_count, - values.data(), - d_error.data()); + extract_fixed_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } } else if (zigzag) { - extract_repeated_varint_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_occurrences.data(), - total_count, - values.data(), - d_error.data()); + extract_varint_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } else { - extract_repeated_varint_kernel - <<>>(message_data, - list_offsets, - base_offset, - d_occurrences.data(), - total_count, - values.data(), - d_error.data()); + extract_varint_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, @@ -2648,14 +2576,14 @@ std::unique_ptr build_repeated_string_column( rmm::device_uvector 
counts(num_rows, stream, mr); thrust::transform(rmm::exec_policy(stream), - d_rep_info.begin(), + d_rep_info.data(), d_rep_info.end(), - counts.begin(), + counts.data(), [] __device__(repeated_field_info const& info) { return info.count; }); rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); + rmm::exec_policy(stream), counts.data(), counts.end(), list_offs.begin(), 0); // Set last offset = total_count CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, @@ -2668,13 +2596,14 @@ std::unique_ptr build_repeated_string_column( rmm::device_uvector str_lengths(total_count, stream, mr); auto const threads = THREADS_PER_BLOCK; auto const blocks = (total_count + threads - 1) / threads; - extract_repeated_lengths_kernel<<>>( - d_occurrences.data(), total_count, str_lengths.data()); + RepeatedLocationProvider loc_provider{nullptr, 0, d_occurrences.data()}; + extract_lengths_kernel + <<>>(loc_provider, total_count, str_lengths.data()); // Compute string offsets via prefix sum rmm::device_uvector str_offsets(total_count + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), str_lengths.begin(), str_lengths.end(), str_offsets.begin(), 0); + rmm::exec_policy(stream), str_lengths.data(), str_lengths.end(), str_offsets.data(), 0); int32_t total_chars = 0; CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, @@ -2698,14 +2627,12 @@ std::unique_ptr build_repeated_string_column( // Copy string data rmm::device_uvector chars(total_chars, stream, mr); + rmm::device_uvector d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); if (total_chars > 0) { - copy_repeated_varlen_data_kernel<<>>(message_data, - list_offsets, - base_offset, - d_occurrences.data(), - total_count, - str_offsets.data(), - chars.data()); + RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; + 
copy_varlen_data_kernel<<>>( + message_data, loc_provider, total_count, str_offsets.data(), chars.data(), d_error.data()); } // Build the child column (either STRING or LIST) @@ -2886,14 +2813,14 @@ std::unique_ptr build_repeated_struct_column( rmm::device_uvector counts(num_rows, stream, mr); thrust::transform(rmm::exec_policy(stream), - d_rep_info.begin(), + d_rep_info.data(), d_rep_info.end(), - counts.begin(), + counts.data(), [] __device__(repeated_field_info const& info) { return info.count; }); rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), counts.begin(), counts.end(), list_offs.begin(), 0); + rmm::exec_policy(stream), counts.data(), counts.end(), list_offs.begin(), 0); // Set last offset = total_count (already computed on caller side) CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, @@ -2969,18 +2896,20 @@ std::unique_ptr build_repeated_struct_column( switch (dt.id()) { case cudf::type_id::BOOL8: { rmm::device_uvector out(total_count, stream, mr); - rmm::device_uvector valid(total_count, stream, mr); + rmm::device_uvector valid((total_count > 0 ? total_count : 1), stream, mr); int64_t def_val = has_def ? (default_bools[child_schema_idx] ? 1 : 0) : 0; - extract_repeated_msg_child_varint_kernel + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, - d_msg_row_offsets.data(), - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, total_count, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_val); @@ -2991,47 +2920,53 @@ std::unique_ptr build_repeated_struct_column( } case cudf::type_id::INT32: { rmm::device_uvector out(total_count, stream, mr); - rmm::device_uvector valid(total_count, stream, mr); + rmm::device_uvector valid((total_count > 0 ? 
total_count : 1), stream, mr); int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_repeated_msg_child_varint_kernel + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, - d_msg_row_offsets.data(), - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, total_count, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_repeated_msg_child_fixed_kernel + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; + extract_fixed_kernel <<>>(message_data, - d_msg_row_offsets.data(), - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, total_count, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, static_cast(def_int)); } else { - extract_repeated_msg_child_varint_kernel + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, - d_msg_row_offsets.data(), - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, total_count, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); @@ -3043,47 +2978,53 @@ std::unique_ptr build_repeated_struct_column( } case cudf::type_id::INT64: { rmm::device_uvector out(total_count, stream, mr); - rmm::device_uvector valid(total_count, stream, mr); + rmm::device_uvector valid((total_count > 0 ? total_count : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_repeated_msg_child_varint_kernel + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, - d_msg_row_offsets.data(), - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, total_count, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_repeated_msg_child_fixed_kernel + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; + extract_fixed_kernel <<>>(message_data, - d_msg_row_offsets.data(), - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, total_count, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); } else { - extract_repeated_msg_child_varint_kernel + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, - d_msg_row_offsets.data(), - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, total_count, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); @@ -3095,18 +3036,20 @@ std::unique_ptr build_repeated_struct_column( } case cudf::type_id::FLOAT32: { rmm::device_uvector out(total_count, stream, mr); - rmm::device_uvector valid(total_count, stream, mr); + rmm::device_uvector valid((total_count > 0 ? total_count : 1), stream, mr); float def_float = has_def ? 
static_cast(default_floats[child_schema_idx]) : 0.0f; - extract_repeated_msg_child_fixed_kernel + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; + extract_fixed_kernel <<>>(message_data, - d_msg_row_offsets.data(), - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, total_count, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_float); @@ -3117,18 +3060,20 @@ std::unique_ptr build_repeated_struct_column( } case cudf::type_id::FLOAT64: { rmm::device_uvector out(total_count, stream, mr); - rmm::device_uvector valid(total_count, stream, mr); + rmm::device_uvector valid((total_count > 0 ? total_count : 1), stream, mr); double def_double = has_def ? default_floats[child_schema_idx] : 0.0; - extract_repeated_msg_child_fixed_kernel + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; + extract_fixed_kernel <<>>(message_data, - d_msg_row_offsets.data(), - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, total_count, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_double); @@ -3197,9 +3142,9 @@ std::unique_ptr build_repeated_struct_column( total_count); // Add base_offset back so build_nested_struct_column can subtract it thrust::transform(rmm::exec_policy(stream), - d_nested_row_offsets_i32.begin(), + d_nested_row_offsets_i32.data(), d_nested_row_offsets_i32.end(), - d_nested_row_offsets.begin(), + d_nested_row_offsets.data(), [base_offset] __device__(int32_t v) { return static_cast(v) + base_offset; }); @@ -3390,16 +3335,16 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector d_rep_counts(num_rows, stream, mr); thrust::transform(rmm::exec_policy(stream), - d_rep_info.begin(), + d_rep_info.data(), 
d_rep_info.end(), - d_rep_counts.begin(), + d_rep_counts.data(), [] __device__(repeated_field_info const& info) { return info.count; }); int total_rep_count = - thrust::reduce(rmm::exec_policy(stream), d_rep_counts.begin(), d_rep_counts.end(), 0); + thrust::reduce(rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), 0); if (total_rep_count == 0) { rmm::device_uvector list_offsets_vec(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), list_offsets_vec.begin(), list_offsets_vec.end(), 0); + thrust::fill(rmm::exec_policy(stream), list_offsets_vec.data(), list_offsets_vec.end(), 0); auto list_offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, num_rows + 1, @@ -3423,7 +3368,7 @@ std::unique_ptr build_nested_struct_column( } else { rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), d_rep_counts.begin(), d_rep_counts.end(), list_offs.begin(), 0); + rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), list_offs.begin(), 0); CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_rep_count, sizeof(int32_t), @@ -3447,18 +3392,14 @@ std::unique_ptr build_nested_struct_column( std::unique_ptr child_values; if (elem_type_id == cudf::type_id::INT32) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_varint_kernel + NestedRepeatedLocationProvider loc_provider{ + list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; + extract_varint_kernel <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, - stream.value()>>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_rep_occs.data(), - total_rep_count, - values.data(), - d_error.data()); + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); child_values = std::make_unique(cudf::data_type{cudf::type_id::INT32}, total_rep_count, values.release(), 
@@ -3466,18 +3407,14 @@ std::unique_ptr build_nested_struct_column( 0); } else if (elem_type_id == cudf::type_id::INT64) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_varint_kernel + NestedRepeatedLocationProvider loc_provider{ + list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; + extract_varint_kernel <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, - stream.value()>>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_rep_occs.data(), - total_rep_count, - values.data(), - d_error.data()); + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); child_values = std::make_unique(cudf::data_type{cudf::type_id::INT64}, total_rep_count, values.release(), @@ -3485,18 +3422,14 @@ std::unique_ptr build_nested_struct_column( 0); } else if (elem_type_id == cudf::type_id::BOOL8) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_varint_kernel + NestedRepeatedLocationProvider loc_provider{ + list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; + extract_varint_kernel <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, - stream.value()>>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_rep_occs.data(), - total_rep_count, - values.data(), - d_error.data()); + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); child_values = std::make_unique(cudf::data_type{cudf::type_id::BOOL8}, total_rep_count, values.release(), @@ -3504,18 +3437,14 @@ std::unique_ptr build_nested_struct_column( 0); } else if (elem_type_id == cudf::type_id::FLOAT32) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_fixed_kernel + NestedRepeatedLocationProvider loc_provider{ + list_offsets, base_offset, d_parent_locs.data(), 
d_rep_occs.data()}; + extract_fixed_kernel <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, - stream.value()>>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_rep_occs.data(), - total_rep_count, - values.data(), - d_error.data()); + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); child_values = std::make_unique(cudf::data_type{cudf::type_id::FLOAT32}, total_rep_count, values.release(), @@ -3523,18 +3452,14 @@ std::unique_ptr build_nested_struct_column( 0); } else if (elem_type_id == cudf::type_id::FLOAT64) { rmm::device_uvector values(total_rep_count, stream, mr); - extract_repeated_in_nested_fixed_kernel + NestedRepeatedLocationProvider loc_provider{ + list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; + extract_fixed_kernel <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, THREADS_PER_BLOCK, 0, - stream.value()>>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_rep_occs.data(), - total_rep_count, - values.data(), - d_error.data()); + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); child_values = std::make_unique(cudf::data_type{cudf::type_id::FLOAT64}, total_rep_count, values.release(), @@ -3543,18 +3468,18 @@ std::unique_ptr build_nested_struct_column( } else if (elem_type_id == cudf::type_id::STRING) { rmm::device_uvector d_str_lengths(total_rep_count, stream, mr); thrust::transform(rmm::exec_policy(stream), - d_rep_occs.begin(), + d_rep_occs.data(), d_rep_occs.end(), - d_str_lengths.begin(), + d_str_lengths.data(), [] __device__(repeated_occurrence const& occ) { return occ.length; }); int32_t total_chars = - thrust::reduce(rmm::exec_policy(stream), d_str_lengths.begin(), d_str_lengths.end(), 0); + thrust::reduce(rmm::exec_policy(stream), d_str_lengths.data(), d_str_lengths.end(), 0); rmm::device_uvector 
str_offs(total_rep_count + 1, stream, mr); thrust::exclusive_scan(rmm::exec_policy(stream), - d_str_lengths.begin(), + d_str_lengths.data(), d_str_lengths.end(), - str_offs.begin(), + str_offs.data(), 0); CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data() + total_rep_count, &total_chars, @@ -3564,19 +3489,18 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { - extract_repeated_in_nested_string_kernel<<<(total_rep_count + THREADS_PER_BLOCK - 1) / - THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_rep_occs.data(), - total_rep_count, - str_offs.data(), - chars.data(), - d_error.data()); + NestedRepeatedLocationProvider loc_provider{ + list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; + copy_varlen_data_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + loc_provider, + total_rep_count, + str_offs.data(), + chars.data(), + d_error.data()); } auto str_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, @@ -3589,16 +3513,16 @@ std::unique_ptr build_nested_struct_column( } else if (elem_type_id == cudf::type_id::LIST) { rmm::device_uvector d_len(total_rep_count, stream, mr); thrust::transform(rmm::exec_policy(stream), - d_rep_occs.begin(), + d_rep_occs.data(), d_rep_occs.end(), - d_len.begin(), + d_len.data(), [] __device__(repeated_occurrence const& occ) { return occ.length; }); int32_t total_bytes = - thrust::reduce(rmm::exec_policy(stream), d_len.begin(), d_len.end(), 0); + thrust::reduce(rmm::exec_policy(stream), d_len.data(), d_len.end(), 0); rmm::device_uvector byte_offs(total_rep_count + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), d_len.begin(), d_len.end(), byte_offs.begin(), 0); + rmm::exec_policy(stream), d_len.data(), d_len.end(), byte_offs.data(), 0); 
CUDF_CUDA_TRY(cudaMemcpyAsync(byte_offs.data() + total_rep_count, &total_bytes, sizeof(int32_t), @@ -3607,19 +3531,18 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector bytes(total_bytes, stream, mr); if (total_bytes > 0) { - extract_repeated_in_nested_string_kernel<<<(total_rep_count + THREADS_PER_BLOCK - 1) / - THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_rep_occs.data(), - total_rep_count, - byte_offs.data(), - bytes.data(), - d_error.data()); + NestedRepeatedLocationProvider loc_provider{ + list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; + copy_varlen_data_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + loc_provider, + total_rep_count, + byte_offs.data(), + bytes.data(), + d_error.data()); } auto offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, @@ -3710,19 +3633,20 @@ std::unique_ptr build_nested_struct_column( switch (dt.id()) { case cudf::type_id::BOOL8: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_val = has_def ? (default_bools[child_schema_idx] ? 
1 : 0) : 0; - extract_nested_varint_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_val); @@ -3733,50 +3657,53 @@ std::unique_ptr build_nested_struct_column( } case cudf::type_id::INT32: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_nested_varint_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_nested_fixed_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_fixed_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, static_cast(def_int)); } else { - extract_nested_varint_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + 
num_child_fields}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); @@ -3788,35 +3715,37 @@ std::unique_ptr build_nested_struct_column( } case cudf::type_id::UINT32: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; if (enc == spark_rapids_jni::ENC_FIXED) { - extract_nested_fixed_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_fixed_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, static_cast(def_int)); } else { - extract_nested_varint_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); @@ -3828,50 +3757,53 @@ std::unique_ptr build_nested_struct_column( } case cudf::type_id::INT64: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_nested_varint_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_nested_fixed_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_fixed_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); } else { - extract_nested_varint_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); @@ -3883,35 +3815,37 @@ std::unique_ptr build_nested_struct_column( } case cudf::type_id::UINT64: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; if (enc == spark_rapids_jni::ENC_FIXED) { - extract_nested_fixed_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_fixed_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, static_cast(def_int)); } else { - extract_nested_varint_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); @@ -3923,19 +3857,20 @@ std::unique_ptr build_nested_struct_column( } case cudf::type_id::FLOAT32: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); float def_float = has_def ? 
static_cast(default_floats[child_schema_idx]) : 0.0f; - extract_nested_fixed_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_fixed_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_float); @@ -3946,19 +3881,20 @@ std::unique_ptr build_nested_struct_column( } case cudf::type_id::FLOAT64: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); double def_double = has_def ? default_floats[child_schema_idx] : 0.0; - extract_nested_fixed_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_fixed_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_double); @@ -3970,19 +3906,20 @@ std::unique_ptr build_nested_struct_column( case cudf::type_id::STRING: { if (enc == spark_rapids_jni::ENC_ENUM_STRING) { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; - extract_nested_varint_kernel + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); @@ -4016,7 +3953,7 @@ std::unique_ptr build_nested_struct_column( int32_t cursor = 0; for (auto const& name : enum_name_bytes) { if (!name.empty()) { - std::copy(name.begin(), name.end(), h_name_chars.begin() + cursor); + std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + cursor); cursor += static_cast(name.size()); } } @@ -4119,15 +4056,10 @@ std::unique_ptr build_nested_struct_column( } rmm::device_uvector lengths(num_rows, stream, mr); - extract_nested_lengths_kernel<<>>( - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - lengths.data(), - num_rows, - has_def_str, - def_len); + NestedLocationProvider loc_provider{ + nullptr, 0, d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields}; + extract_lengths_kernel<<>>( + loc_provider, num_rows, lengths.data(), has_def_str, def_len); rmm::device_uvector output_offsets(num_rows + 1, stream, mr); thrust::exclusive_scan( @@ -4155,28 +4087,30 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { - copy_nested_varlen_data_kernel<<>>( - message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - output_offsets.data(), - chars.data(), - num_rows, - has_def_str, - d_default_str.data(), - def_len); + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + 
copy_varlen_data_kernel + <<>>(message_data, + loc_provider, + num_rows, + output_offsets.data(), + chars.data(), + d_error.data(), + has_def_str, + d_default_str.data(), + def_len); } - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); thrust::transform( rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), - valid.begin(), + valid.data(), [plocs = d_parent_locs.data(), flocs = d_child_locations.data(), ci, @@ -4213,15 +4147,10 @@ std::unique_ptr build_nested_struct_column( } rmm::device_uvector lengths(num_rows, stream, mr); - extract_nested_lengths_kernel<<>>( - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - lengths.data(), - num_rows, - has_def_bytes, - def_len); + NestedLocationProvider loc_provider{ + nullptr, 0, d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields}; + extract_lengths_kernel<<>>( + loc_provider, num_rows, lengths.data(), has_def_bytes, def_len); rmm::device_uvector output_offsets(num_rows + 1, stream, mr); thrust::exclusive_scan( @@ -4249,28 +4178,30 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector bytes_data(total_bytes, stream, mr); if (total_bytes > 0) { - copy_nested_varlen_data_kernel<<>>( - message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - output_offsets.data(), - bytes_data.data(), - num_rows, - has_def_bytes, - d_default_bytes.data(), - def_len); + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + copy_varlen_data_kernel + <<>>(message_data, + loc_provider, + num_rows, + output_offsets.data(), + bytes_data.data(), + d_error.data(), + has_def_bytes, + d_default_bytes.data(), + def_len); } - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? 
num_rows : 1), stream, mr); thrust::transform( rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), - valid.begin(), + valid.data(), [plocs = d_parent_locs.data(), flocs = d_child_locations.data(), ci, @@ -4341,12 +4272,12 @@ std::unique_ptr build_nested_struct_column( } } - rmm::device_uvector struct_valid(num_rows, stream, mr); + rmm::device_uvector struct_valid((num_rows > 0 ? num_rows : 1), stream, mr); thrust::transform( rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), - struct_valid.begin(), + struct_valid.data(), [plocs = d_parent_locs.data()] __device__(auto row) { return plocs[row].offset >= 0; }); auto [struct_mask, struct_null_count] = make_null_mask_from_valid(struct_valid, stream, mr); return cudf::make_structs_column( @@ -4355,18 +4286,18 @@ std::unique_ptr build_nested_struct_column( } // anonymous namespace -std::unique_ptr decode_protobuf_to_struct( - cudf::column_view const& binary_input, - std::vector const& schema, - std::vector const& schema_output_types, - std::vector const& default_ints, - std::vector const& default_floats, - std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, - bool fail_on_errors) +std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, + ProtobufDecodeContext const& context) { + auto const& schema = context.schema; + auto const& schema_output_types = context.schema_output_types; + auto const& default_ints = context.default_ints; + auto const& default_floats = context.default_floats; + auto const& default_bools = context.default_bools; + auto const& default_strings = context.default_strings; + auto const& enum_valid_values = context.enum_valid_values; + auto const& enum_names = context.enum_names; + bool fail_on_errors = context.fail_on_errors; CUDF_EXPECTS(binary_input.type().id() == 
cudf::type_id::LIST, "binary_input must be a LIST column"); cudf::lists_column_view const in_list(binary_input); @@ -4586,18 +4517,16 @@ std::unique_ptr decode_protobuf_to_struct( switch (dt.id()) { case cudf::type_id::BOOL8: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_val = has_def ? (default_bools[schema_idx] ? 1 : 0) : 0; - extract_varint_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_val); @@ -4608,47 +4537,41 @@ std::unique_ptr decode_protobuf_to_struct( } case cudf::type_id::INT32: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_varint_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_fixed_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, static_cast(def_int)); } else { - extract_varint_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); @@ -4679,33 +4602,29 @@ std::unique_ptr decode_protobuf_to_struct( } case cudf::type_id::UINT32: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_fixed_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, static_cast(def_int)); } else { - extract_varint_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); @@ -4717,47 +4636,41 @@ std::unique_ptr decode_protobuf_to_struct( } case cudf::type_id::INT64: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; if (enc == spark_rapids_jni::ENC_ZIGZAG) { - extract_varint_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); } else if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_fixed_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); } else { - extract_varint_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); @@ -4769,33 +4682,29 @@ std::unique_ptr decode_protobuf_to_struct( } case cudf::type_id::UINT64: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; if (enc == spark_rapids_jni::ENC_FIXED) { - extract_fixed_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_fixed_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, static_cast(def_int)); } else { - extract_varint_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); @@ -4807,18 +4716,16 @@ std::unique_ptr decode_protobuf_to_struct( } case cudf::type_id::FLOAT32: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); float def_float = has_def ? static_cast(default_floats[schema_idx]) : 0.0f; - extract_fixed_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_fixed_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_float); @@ -4829,18 +4736,16 @@ std::unique_ptr decode_protobuf_to_struct( } case cudf::type_id::FLOAT64: { rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); double def_double = has_def ? 
default_floats[schema_idx] : 0.0; - extract_fixed_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_fixed_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_double); @@ -4856,18 +4761,16 @@ std::unique_ptr decode_protobuf_to_struct( // 2. Validate against enum_valid_values. // 3. Convert INT32 -> UTF-8 enum name bytes. rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? default_ints[schema_idx] : 0; - extract_varint_from_locations_kernel + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - out.data(), - valid.data(), + loc_provider, num_rows, + out.data(), + (bool*)valid.data(), d_error.data(), has_def, def_int); @@ -4903,7 +4806,7 @@ std::unique_ptr decode_protobuf_to_struct( int32_t cursor = 0; for (auto const& name : enum_name_bytes) { if (!name.empty()) { - std::copy(name.begin(), name.end(), h_name_chars.begin() + cursor); + std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + cursor); cursor += static_cast(name.size()); } } @@ -5011,8 +4914,10 @@ std::unique_ptr decode_protobuf_to_struct( // Extract string lengths rmm::device_uvector lengths(num_rows, stream, mr); - extract_scalar_string_lengths_kernel<<>>( - d_locations.data(), i, num_scalar, lengths.data(), num_rows, has_def_str, def_len); + TopLevelLocationProvider loc_provider{nullptr, 0, d_locations.data(), i, num_scalar}; + extract_lengths_kernel + <<>>( + loc_provider, num_rows, lengths.data(), has_def_str, def_len); // 
Compute offsets via prefix sum rmm::device_uvector output_offsets(num_rows + 1, stream, mr); @@ -5042,28 +4947,27 @@ std::unique_ptr decode_protobuf_to_struct( // Copy string data rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { - copy_scalar_string_data_kernel<<>>( - message_data, - list_offsets, - base_offset, - d_locations.data(), - i, - num_scalar, - output_offsets.data(), - chars.data(), - num_rows, - has_def_str, - d_default_str.data(), - def_len); + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + copy_varlen_data_kernel + <<>>(message_data, + loc_provider, + num_rows, + output_offsets.data(), + chars.data(), + d_error.data(), + has_def_str, + d_default_str.data(), + def_len); } // Build validity mask - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); thrust::transform( rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), - valid.begin(), + valid.data(), [locs = d_locations.data(), i, num_scalar, has_def_str] __device__(auto row) { return locs[row * num_scalar + i].offset >= 0 || has_def_str; }); @@ -5095,8 +4999,9 @@ std::unique_ptr decode_protobuf_to_struct( } rmm::device_uvector lengths(num_rows, stream, mr); - extract_scalar_string_lengths_kernel<<>>( - d_locations.data(), i, num_scalar, lengths.data(), num_rows, has_def_bytes, def_len); + TopLevelLocationProvider loc_provider{nullptr, 0, d_locations.data(), i, num_scalar}; + extract_lengths_kernel<<>>( + loc_provider, num_rows, lengths.data(), has_def_bytes, def_len); rmm::device_uvector output_offsets(num_rows + 1, stream, mr); thrust::exclusive_scan( @@ -5124,27 +5029,26 @@ std::unique_ptr decode_protobuf_to_struct( rmm::device_uvector bytes_data(total_bytes, stream, mr); if (total_bytes > 0) { - copy_scalar_string_data_kernel<<>>( - message_data, - list_offsets, - base_offset, - d_locations.data(), 
- i, - num_scalar, - output_offsets.data(), - bytes_data.data(), - num_rows, - has_def_bytes, - d_default_bytes.data(), - def_len); + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + copy_varlen_data_kernel + <<>>(message_data, + loc_provider, + num_rows, + output_offsets.data(), + bytes_data.data(), + d_error.data(), + has_def_bytes, + d_default_bytes.data(), + def_len); } - rmm::device_uvector valid(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); thrust::transform( rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), - valid.begin(), + valid.data(), [locs = d_locations.data(), i, num_scalar, has_def_bytes] __device__(auto row) { return locs[row * num_scalar + i].offset >= 0 || has_def_bytes; }); @@ -5200,7 +5104,7 @@ std::unique_ptr decode_protobuf_to_struct( thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), - d_field_counts.begin(), + d_field_counts.data(), extract_strided_count{d_repeated_info.data(), ri, num_repeated}); int total_count = @@ -5218,7 +5122,7 @@ std::unique_ptr decode_protobuf_to_struct( thrust::exclusive_scan(rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), - d_occ_offsets.begin(), + d_occ_offsets.data(), 0); // Set last element CUDF_CUDA_TRY(cudaMemcpyAsync(d_occ_offsets.data() + num_rows, @@ -5351,7 +5255,7 @@ std::unique_ptr decode_protobuf_to_struct( int32_t cursor = 0; for (auto const& nm : name_bytes) { if (!nm.empty()) { - std::copy(nm.begin(), nm.end(), h_name_chars.begin() + cursor); + std::copy(nm.data(), nm.data() + nm.size(), h_name_chars.data() + cursor); cursor += static_cast(nm.size()); } } @@ -5373,7 +5277,7 @@ std::unique_ptr decode_protobuf_to_struct( // 3. Validate enum values (sets row_has_invalid_enum for PERMISSIVE mode). // We also need per-element validity for string building. 
rmm::device_uvector elem_valid(total_count, stream, mr); - thrust::fill(rmm::exec_policy(stream), elem_valid.begin(), elem_valid.end(), true); + thrust::fill(rmm::exec_policy(stream), elem_valid.data(), elem_valid.end(), true); // validate_enum_values_kernel works on per-row basis; here we need per-element. // Binary-search each element inline via the lengths kernel below. @@ -5394,9 +5298,9 @@ std::unique_ptr decode_protobuf_to_struct( // 5. Build string offsets rmm::device_uvector str_offsets(total_count + 1, stream, mr); thrust::exclusive_scan(rmm::exec_policy(stream), - elem_lengths.begin(), + elem_lengths.data(), elem_lengths.end(), - str_offsets.begin(), + str_offsets.data(), 0); int32_t total_chars = 0; @@ -5671,7 +5575,6 @@ std::unique_ptr decode_protobuf_to_struct( } } - // Check for errors CUDF_CUDA_TRY(cudaPeekAtLastError()); int h_error = 0; CUDF_CUDA_TRY( diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index c8d94c3fc2..c95e16a2a5 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -50,6 +50,21 @@ struct nested_field_descriptor { bool has_default_value; // Whether this field has a default value }; +/** + * Context and schema information for decoding protobuf messages. + */ +struct ProtobufDecodeContext { + std::vector schema; + std::vector schema_output_types; + std::vector default_ints; + std::vector default_floats; + std::vector default_bools; + std::vector> default_strings; + std::vector> enum_valid_values; + std::vector>> enum_names; + bool fail_on_errors; +}; + /** * Decode protobuf messages (one message per row) from a LIST column into a STRUCT * column, with support for nested messages and repeated fields. 
@@ -75,28 +90,10 @@ struct nested_field_descriptor { * - STRUCT : protobuf nested `message` * * @param binary_input LIST column, each row is one protobuf message - * @param schema Flattened schema with parent-child relationships - * @param schema_output_types Output types for each field in schema (cudf types) - * @param default_ints Default values for int/long/enum fields - * @param default_floats Default values for float/double fields - * @param default_bools Default values for bool fields - * @param default_strings Default values for string/bytes fields - * @param enum_valid_values Valid enum values for each field (empty if not enum) - * @param enum_names Enum names for enum-as-string fields (empty if not enum-as-string), - * ordered in parallel with enum_valid_values - * @param fail_on_errors Whether to throw on malformed data + * @param context Decoding context containing schema and default values * @return STRUCT column with nested structure */ -std::unique_ptr decode_protobuf_to_struct( - cudf::column_view const& binary_input, - std::vector const& schema, - std::vector const& schema_output_types, - std::vector const& default_ints, - std::vector const& default_floats, - std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, - bool fail_on_errors); +std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, + ProtobufDecodeContext const& context); } // namespace spark_rapids_jni From ca5a9212a3d0a1abbeb595ee5ef4a069b0306e63 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 25 Feb 2026 16:31:26 +0800 Subject: [PATCH 026/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 7 +++++++ src/main/cpp/src/protobuf.cu | 2 +- src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java | 6 +++--- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp 
index de7ca90355..3c157ba9a7 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -132,11 +132,13 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, default_string_values.reserve(num_fields); for (int i = 0; i < num_fields; ++i) { jbyteArray byte_arr = static_cast(env->GetObjectArrayElement(default_strings, i)); + if (env->ExceptionCheck()) { return 0; } if (byte_arr == nullptr) { default_string_values.emplace_back(); } else { jsize len = env->GetArrayLength(byte_arr); jbyte* bytes = env->GetByteArrayElements(byte_arr, nullptr); + if (bytes == nullptr) { return 0; } default_string_values.emplace_back(reinterpret_cast(bytes), reinterpret_cast(bytes) + len); env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT); @@ -148,11 +150,13 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, enum_values.reserve(num_fields); for (int i = 0; i < num_fields; ++i) { jintArray int_arr = static_cast(env->GetObjectArrayElement(enum_valid_values, i)); + if (env->ExceptionCheck()) { return 0; } if (int_arr == nullptr) { enum_values.emplace_back(); } else { jsize len = env->GetArrayLength(int_arr); jint* ints = env->GetIntArrayElements(int_arr, nullptr); + if (ints == nullptr) { return 0; } enum_values.emplace_back(ints, ints + len); env->ReleaseIntArrayElements(int_arr, ints, JNI_ABORT); } @@ -165,6 +169,7 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, enum_name_values.reserve(num_fields); for (int i = 0; i < num_fields; ++i) { jobjectArray names_arr = static_cast(env->GetObjectArrayElement(enum_names, i)); + if (env->ExceptionCheck()) { return 0; } if (names_arr == nullptr) { enum_name_values.emplace_back(); } else { @@ -173,11 +178,13 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, names_for_field.reserve(num_names); for (jsize j = 0; j < num_names; ++j) { jbyteArray name_bytes = static_cast(env->GetObjectArrayElement(names_arr, j)); + if 
(env->ExceptionCheck()) { return 0; } if (name_bytes == nullptr) { names_for_field.emplace_back(); } else { jsize len = env->GetArrayLength(name_bytes); jbyte* bytes = env->GetByteArrayElements(name_bytes, nullptr); + if (bytes == nullptr) { return 0; } names_for_field.emplace_back(reinterpret_cast(bytes), reinterpret_cast(bytes) + len); env->ReleaseByteArrayElements(name_bytes, bytes, JNI_ABORT); diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index a7d42e0967..72856b3dfc 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -3252,7 +3252,7 @@ std::unique_ptr build_nested_struct_column( } auto const threads = THREADS_PER_BLOCK; - auto const blocks = (num_rows + threads - 1) / threads; + auto const blocks = static_cast((num_rows + threads - 1) / threads); int num_child_fields = static_cast(child_field_indices.size()); std::vector h_child_field_descs(num_child_fields); diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index 03ead2f4a1..e97a38f452 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -134,12 +134,12 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, throw new IllegalArgumentException("All arrays must have the same length"); } - // Validate field numbers are positive + // Validate field numbers are positive and within protobuf spec range for (int i = 0; i < fieldNumbers.length; i++) { - if (fieldNumbers[i] <= 0) { + if (fieldNumbers[i] <= 0 || fieldNumbers[i] > 536870911) { throw new IllegalArgumentException( "Invalid field number at index " + i + ": " + fieldNumbers[i] + - " (field numbers must be positive)"); + " (field numbers must be 1-536870911)"); } } From edbfd98e75c868710fcc9ef65ad0bc7b8c578485 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 25 Feb 2026 18:00:36 +0800 Subject: [PATCH 027/107] ai self review and 
comment addressed Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 41 +- src/main/cpp/src/protobuf.cu | 2264 +++++++---------- src/main/cpp/src/protobuf.hpp | 5 +- .../nvidia/spark/rapids/jni/ProtobufTest.java | 142 ++ 4 files changed, 1150 insertions(+), 1302 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index 3c157ba9a7..af6f9c6813 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -19,6 +19,7 @@ #include "protobuf.hpp" #include +#include #include extern "C" { @@ -94,6 +95,43 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, 0); } + // Validate schema topology and wire types: + // - parent index must be -1 or a prior field index + // - depth must be 0 for top-level and parent_depth + 1 for children + // - wire type must be one of {0, 1, 2, 5} + for (int i = 0; i < num_fields; ++i) { + auto const parent_idx = n_parent_indices[i]; + auto const depth = n_depth_levels[i]; + auto const wire_type = n_wire_types[i]; + + if (!(wire_type == 0 || wire_type == 1 || wire_type == 2 || wire_type == 5)) { + JNI_THROW_NEW( + env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, "wire_types must be one of {0,1,2,5}", 0); + } + + if (parent_idx < -1 || parent_idx >= num_fields || parent_idx >= i) { + JNI_THROW_NEW(env, + cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, + "parent_indices must be -1 or a valid prior field index", + 0); + } + + if (parent_idx == -1) { + if (depth != 0) { + JNI_THROW_NEW( + env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, "top-level fields must have depth 0", 0); + } + } else { + auto const parent_depth = n_depth_levels[parent_idx]; + if (depth != parent_depth + 1) { + JNI_THROW_NEW(env, + cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, + "child depth must equal parent depth + 1", + 0); + } + } + } + // Build schema descriptors std::vector schema; schema.reserve(num_fields); @@ -204,7 +242,8 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* 
env, std::move(enum_name_values), static_cast(fail_on_errors)}; - auto result = spark_rapids_jni::decode_protobuf_to_struct(*input, context); + auto result = + spark_rapids_jni::decode_protobuf_to_struct(*input, context, cudf::get_default_stream()); return cudf::jni::release_as_jlong(result); } diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 72856b3dfc..37a95f1f8a 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -118,6 +118,21 @@ struct device_nested_field_descriptor { bool is_repeated; bool is_required; bool has_default_value; + + device_nested_field_descriptor() = default; + + explicit device_nested_field_descriptor(spark_rapids_jni::nested_field_descriptor const& src) + : field_number(src.field_number), + parent_idx(src.parent_idx), + depth(src.depth), + wire_type(src.wire_type), + output_type_id(static_cast(src.output_type)), + encoding(src.encoding), + is_repeated(src.is_repeated), + is_required(src.is_required), + has_default_value(src.has_default_value) + { + } }; // ============================================================================ @@ -372,7 +387,9 @@ __global__ void scan_all_fields_kernel( } locations[row * num_fields + f] = {data_offset, field_size}; } - // Don't break - continue to support "last one wins" semantics + // "Last one wins" is preserved across later message tags, no need to keep scanning + // descriptors for the same tag once matched. 
+ break; } } @@ -890,17 +907,13 @@ __global__ void copy_varlen_data_kernel(uint8_t const* message_data, if (loc.offset < 0) { if (has_default && default_len > 0) { - for (int i = 0; i < default_len; i++) { - output_chars[out_start + i] = static_cast(default_chars[i]); - } + memcpy(output_chars + out_start, default_chars, default_len); } return; } uint8_t const* src = message_data + data_offset; - for (int i = 0; i < loc.length; i++) { - output_chars[out_start + i] = static_cast(src[i]); - } + memcpy(output_chars + out_start, src, loc.length); } /** @@ -1194,6 +1207,132 @@ inline std::pair make_null_mask_from_valid( return cudf::detail::valid_if(begin, end, pred, stream, mr); } +inline void build_offsets_from_lengths(rmm::device_uvector const& lengths, + rmm::device_uvector& offsets, + rmm::cuda_stream_view stream) +{ + CUDF_EXPECTS(offsets.size() == lengths.size() + 1, "offsets size must equal lengths size + 1"); + CUDF_CUDA_TRY(cudaMemsetAsync(offsets.data(), 0, sizeof(int32_t), stream.value())); + if (lengths.size() > 0) { + thrust::inclusive_scan( + rmm::exec_policy(stream), lengths.begin(), lengths.end(), offsets.begin() + 1); + } +} + +template +std::unique_ptr extract_and_build_scalar_column(cudf::data_type dt, + int num_rows, + LaunchFn&& launch_extract, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); + launch_extract(out.data(), valid.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + return std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count); +} + +template +// Shared integer extractor for INT32/INT64/UINT32/UINT64 decode paths. 
+inline void extract_integer_into_buffers(uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_rows, + int blocks, + int threads, + bool has_default, + int64_t default_value, + int encoding, + bool enable_zigzag, + T* out_ptr, + bool* valid_ptr, + int* error_ptr, + rmm::cuda_stream_view stream) +{ + if (enable_zigzag && encoding == spark_rapids_jni::ENC_ZIGZAG) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + default_value); + } else if (encoding == spark_rapids_jni::ENC_FIXED) { + if constexpr (sizeof(T) == 4) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + static_cast(default_value)); + } else { + static_assert(sizeof(T) == 8, "extract_integer_into_buffers only supports 32/64-bit"); + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + static_cast(default_value)); + } + } else { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + default_value); + } +} + +template +// Builds a scalar column for integer-like protobuf fields. 
+std::unique_ptr extract_and_build_integer_column(cudf::data_type dt, + uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_rows, + int blocks, + int threads, + rmm::device_uvector& d_error, + bool has_default, + int64_t default_value, + int encoding, + bool enable_zigzag, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + return extract_and_build_scalar_column( + dt, + num_rows, + [&](T* out_ptr, bool* valid_ptr) { + extract_integer_into_buffers(message_data, + loc_provider, + num_rows, + blocks, + threads, + has_default, + default_value, + encoding, + enable_zigzag, + out_ptr, + valid_ptr, + d_error.data(), + stream); + }, + stream, + mr); +} + /** * Scan for child fields within repeated message occurrences. * Each occurrence is a protobuf message, and we need to find child field locations within it. @@ -1236,10 +1375,10 @@ __global__ void scan_repeated_message_children_kernel( // Check against child field descriptors for (int f = 0; f < num_child_fields; f++) { if (child_descs[f].field_number == fn) { - if (wt != child_descs[f].expected_wire_type) { - // Wire type mismatch - could be OK for some cases (e.g., packed vs unpacked) - // For now, just continue - continue; + bool is_packed = (wt == WT_LEN && child_descs[f].expected_wire_type != WT_LEN); + if (!is_packed && wt != child_descs[f].expected_wire_type) { + atomicExch(error_flag, ERR_WIRE_TYPE); + return; } int data_offset = static_cast(cur - msg_start); @@ -1326,18 +1465,65 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, for (int ri = 0; ri < num_repeated; ri++) { int schema_idx = repeated_indices[ri]; if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { - int data_len = 0; - if (wt == WT_LEN) { - uint64_t len; + int expected_wt = schema[schema_idx].wire_type; + bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); + + if (!is_packed && wt != expected_wt) { + atomicExch(error_flag, ERR_WIRE_TYPE); 
+ return; + } + + if (is_packed) { + uint64_t packed_len; int len_bytes; - if (!read_varint(cur, msg_end, len, len_bytes)) { + if (!read_varint(cur, msg_end, packed_len, len_bytes)) { atomicExch(error_flag, ERR_VARINT); return; } - data_len = static_cast(len); + uint8_t const* packed_start = cur + len_bytes; + uint8_t const* packed_end = packed_start + packed_len; + if (packed_end > msg_end) { + atomicExch(error_flag, ERR_OVERFLOW); + return; + } + + int count = 0; + if (expected_wt == WT_VARINT) { + uint8_t const* p = packed_start; + while (p < packed_end) { + uint64_t dummy; + int vbytes; + if (!read_varint(p, packed_end, dummy, vbytes)) { + atomicExch(error_flag, ERR_VARINT); + return; + } + p += vbytes; + count++; + } + } else if (expected_wt == WT_32BIT) { + if ((packed_len % 4) != 0) { + atomicExch(error_flag, ERR_FIXED_LEN); + return; + } + count = static_cast(packed_len / 4); + } else if (expected_wt == WT_64BIT) { + if ((packed_len % 8) != 0) { + atomicExch(error_flag, ERR_FIXED_LEN); + return; + } + count = static_cast(packed_len / 8); + } + repeated_info[row * num_repeated + ri].count += count; + repeated_info[row * num_repeated + ri].total_length += static_cast(packed_len); + } else { + int32_t data_offset, data_len; + if (!get_field_data_location(cur, msg_end, wt, data_offset, data_len)) { + atomicExch(error_flag, ERR_FIELD_SIZE); + return; + } + repeated_info[row * num_repeated + ri].count++; + repeated_info[row * num_repeated + ri].total_length += data_len; } - repeated_info[row * num_repeated + ri].count++; - repeated_info[row * num_repeated + ri].total_length += data_len; } } @@ -1393,30 +1579,88 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, for (int ri = 0; ri < num_repeated; ri++) { int schema_idx = repeated_indices[ri]; if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { - int32_t data_offset = static_cast(cur - msg_start); - int32_t data_len = 0; + int expected_wt = 
schema[schema_idx].wire_type; + bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); - if (wt == WT_LEN) { - uint64_t len; + if (!is_packed && wt != expected_wt) { + atomicExch(error_flag, ERR_WIRE_TYPE); + return; + } + + if (is_packed) { + uint64_t packed_len; int len_bytes; - if (!read_varint(cur, msg_end, len, len_bytes)) { + if (!read_varint(cur, msg_end, packed_len, len_bytes)) { atomicExch(error_flag, ERR_VARINT); return; } - data_offset += len_bytes; - data_len = static_cast(len); - } else if (wt == WT_VARINT) { - uint64_t dummy; - int vbytes; - if (read_varint(cur, msg_end, dummy, vbytes)) { data_len = vbytes; } - } else if (wt == WT_32BIT) { - data_len = 4; - } else if (wt == WT_64BIT) { - data_len = 8; - } + uint8_t const* packed_start = cur + len_bytes; + uint8_t const* packed_end = packed_start + packed_len; + if (packed_end > msg_end) { + atomicExch(error_flag, ERR_OVERFLOW); + return; + } + + if (expected_wt == WT_VARINT) { + uint8_t const* p = packed_start; + while (p < packed_end) { + int32_t elem_offset = static_cast(p - msg_start); + uint64_t dummy; + int vbytes; + if (!read_varint(p, packed_end, dummy, vbytes)) { + atomicExch(error_flag, ERR_VARINT); + return; + } + occurrences[occ_offset + occ_idx] = {row, elem_offset, vbytes}; + occ_idx++; + p += vbytes; + } + } else if (expected_wt == WT_32BIT) { + if ((packed_len % 4) != 0) { + atomicExch(error_flag, ERR_FIXED_LEN); + return; + } + for (uint64_t i = 0; i < packed_len; i += 4) { + occurrences[occ_offset + occ_idx] = { + row, static_cast(packed_start - msg_start + i), 4}; + occ_idx++; + } + } else if (expected_wt == WT_64BIT) { + if ((packed_len % 8) != 0) { + atomicExch(error_flag, ERR_FIXED_LEN); + return; + } + for (uint64_t i = 0; i < packed_len; i += 8) { + occurrences[occ_offset + occ_idx] = { + row, static_cast(packed_start - msg_start + i), 8}; + occ_idx++; + } + } + } else { + int32_t data_offset = static_cast(cur - msg_start); + int32_t data_len = 0; + if (wt == WT_LEN) { + 
uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + atomicExch(error_flag, ERR_VARINT); + return; + } + data_offset += len_bytes; + data_len = static_cast(len); + } else if (wt == WT_VARINT) { + uint64_t dummy; + int vbytes; + if (read_varint(cur, msg_end, dummy, vbytes)) { data_len = vbytes; } + } else if (wt == WT_32BIT) { + data_len = 4; + } else if (wt == WT_64BIT) { + data_len = 8; + } - occurrences[occ_offset + occ_idx] = {row, data_offset, data_len}; - occ_idx++; + occurrences[occ_offset + occ_idx] = {row, data_offset, data_len}; + occ_idx++; + } } } @@ -1571,33 +1815,11 @@ inline std::unique_ptr build_repeated_msg_child_string_column( return loc.offset >= 0 ? loc.length : 0; }); - // Compute offsets via exclusive scan + // Compute offsets without host round-trip rmm::device_uvector d_str_offsets(total_count + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), d_lengths.begin(), d_lengths.end(), d_str_offsets.begin(), 0); - - // Get total chars count - int32_t total_chars = 0; - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, - d_str_offsets.data() + total_count - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - d_lengths.data() + total_count - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - total_chars += last_len; - - // Set final offset - CUDF_CUDA_TRY(cudaMemcpyAsync(d_str_offsets.data() + total_count, - &total_chars, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + build_offsets_from_lengths(d_lengths, d_str_offsets, stream); + int32_t total_chars = + thrust::reduce(rmm::exec_policy(stream), d_lengths.begin(), d_lengths.end(), 0); // Allocate output chars and validity rmm::device_uvector d_chars(total_chars, stream, mr); @@ -1687,28 +1909,9 @@ inline std::unique_ptr build_repeated_msg_child_bytes_column( }); rmm::device_uvector d_offs(total_count + 1, 
stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), d_lengths.begin(), d_lengths.end(), d_offs.data(), 0); - - int32_t total_bytes = 0; - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, - d_offs.data() + total_count - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - d_lengths.data() + total_count - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - total_bytes += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(d_offs.data() + total_count, - &total_bytes, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + build_offsets_from_lengths(d_lengths, d_offs, stream); + int32_t total_bytes = + thrust::reduce(rmm::exec_policy(stream), d_lengths.begin(), d_lengths.end(), 0); rmm::device_uvector d_bytes(total_bytes, stream, mr); rmm::device_uvector d_valid((total_count > 0 ? total_count : 1), stream, mr); @@ -2177,6 +2380,8 @@ std::unique_ptr make_empty_struct_column_with_schema( } // namespace +namespace { + // ============================================================================ // Kernel to check required fields after scan pass // ============================================================================ @@ -2340,6 +2545,182 @@ __global__ void copy_enum_string_chars_kernel( } } +std::unique_ptr build_enum_string_column( + rmm::device_uvector& enum_values, + rmm::device_uvector& valid, + std::vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_has_invalid_enum, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((num_rows + threads - 1) / threads); + + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + 
stream.value())); + + validate_enum_values_kernel<<>>( + enum_values.data(), + valid.data(), + d_row_has_invalid_enum.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + num_rows); + + std::vector h_name_offsets(valid_enums.size() + 1, 0); + int32_t total_name_chars = 0; + for (size_t k = 0; k < enum_name_bytes.size(); ++k) { + total_name_chars += static_cast(enum_name_bytes[k].size()); + h_name_offsets[k + 1] = total_name_chars; + } + std::vector h_name_chars(total_name_chars); + int32_t cursor = 0; + for (auto const& name : enum_name_bytes) { + if (!name.empty()) { + std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + cursor); + cursor += static_cast(name.size()); + } + } + + rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), + h_name_offsets.data(), + h_name_offsets.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + rmm::device_uvector d_name_chars(total_name_chars, stream, mr); + if (total_name_chars > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), + h_name_chars.data(), + total_name_chars * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream.value())); + } + + rmm::device_uvector lengths(num_rows, stream, mr); + compute_enum_string_lengths_kernel<<>>( + enum_values.data(), + valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + static_cast(valid_enums.size()), + lengths.data(), + num_rows); + + rmm::device_uvector output_offsets(num_rows + 1, stream, mr); + build_offsets_from_lengths(lengths, output_offsets, stream); + int32_t total_chars = thrust::reduce(rmm::exec_policy(stream), lengths.begin(), lengths.end(), 0); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + copy_enum_string_chars_kernel<<>>( + enum_values.data(), + valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + d_name_chars.data(), + static_cast(valid_enums.size()), + output_offsets.data(), + 
chars.data(), + num_rows); + } + + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); + return cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); +} + +template +std::unique_ptr extract_and_build_string_or_bytes_column( + bool as_bytes, + uint8_t const* message_data, + int num_rows, + LengthProvider const& length_provider, + CopyProvider const& copy_provider, + ValidityFn validity_fn, + bool has_default, + std::vector const& default_bytes, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + int32_t def_len = has_default ? static_cast(default_bytes.size()) : 0; + rmm::device_uvector d_default(def_len, stream, mr); + if (has_default && def_len > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync( + d_default.data(), default_bytes.data(), def_len, cudaMemcpyHostToDevice, stream.value())); + } + + rmm::device_uvector lengths(num_rows, stream, mr); + auto const threads = THREADS_PER_BLOCK; + auto const blocks = (num_rows + threads - 1) / threads; + extract_lengths_kernel<<>>( + length_provider, num_rows, lengths.data(), has_default, def_len); + + rmm::device_uvector output_offsets(num_rows + 1, stream, mr); + build_offsets_from_lengths(lengths, output_offsets, stream); + int32_t total_size = thrust::reduce(rmm::exec_policy(stream), lengths.begin(), lengths.end(), 0); + + rmm::device_uvector chars(total_size, stream, mr); + if (total_size > 0) { + copy_varlen_data_kernel + <<>>(message_data, + copy_provider, + num_rows, + output_offsets.data(), + chars.data(), + d_error.data(), + has_default, + d_default.data(), + def_len); + } + + rmm::device_uvector valid((num_rows > 0 ? 
num_rows : 1), stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.data(), + validity_fn); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + output_offsets.release(), + rmm::device_buffer{}, + 0); + if (as_bytes) { + auto bytes_child = + std::make_unique(cudf::data_type{cudf::type_id::UINT8}, + total_size, + rmm::device_buffer(chars.data(), total_size, stream, mr), + rmm::device_buffer{}, + 0); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(bytes_child), + null_count, + std::move(mask), + stream, + mr); + } + + return cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); +} + +} // namespace + namespace spark_rapids_jni { namespace { @@ -2408,22 +2789,18 @@ std::unique_ptr build_repeated_scalar_column( &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - // Build list offsets from counts entirely on GPU (performance fix!) - // Copy h_repeated_info to device and use thrust::transform to extract counts - rmm::device_uvector d_rep_info(num_rows, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_info.data(), - h_repeated_info.data(), - num_rows * sizeof(repeated_field_info), + // Build list offsets from per-row counts. 
+ std::vector h_counts(num_rows); + for (int row = 0; row < num_rows; ++row) { + h_counts[row] = h_repeated_info[row].count; + } + rmm::device_uvector counts(num_rows, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(counts.data(), + h_counts.data(), + num_rows * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - rmm::device_uvector counts(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - d_rep_info.data(), - d_rep_info.end(), - counts.data(), - [] __device__(repeated_field_info const& info) { return info.count; }); - rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( rmm::exec_policy(stream), counts.data(), counts.end(), list_offs.begin(), 0); @@ -2565,22 +2942,18 @@ std::unique_ptr build_repeated_string_column( &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - // Build list offsets from counts entirely on GPU (performance fix!) - // Copy h_repeated_info to device and use thrust::transform to extract counts - rmm::device_uvector d_rep_info(num_rows, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_info.data(), - h_repeated_info.data(), - num_rows * sizeof(repeated_field_info), + // Build list offsets from per-row counts. 
+ std::vector h_counts(num_rows); + for (int row = 0; row < num_rows; ++row) { + h_counts[row] = h_repeated_info[row].count; + } + rmm::device_uvector counts(num_rows, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(counts.data(), + h_counts.data(), + num_rows * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - rmm::device_uvector counts(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - d_rep_info.data(), - d_rep_info.end(), - counts.data(), - [] __device__(repeated_field_info const& info) { return info.count; }); - rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( rmm::exec_policy(stream), counts.data(), counts.end(), list_offs.begin(), 0); @@ -2602,28 +2975,9 @@ std::unique_ptr build_repeated_string_column( // Compute string offsets via prefix sum rmm::device_uvector str_offsets(total_count + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), str_lengths.data(), str_lengths.end(), str_offsets.data(), 0); - - int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, - str_offsets.data() + total_count - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - str_lengths.data() + total_count - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - total_chars += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(str_offsets.data() + total_count, - &total_chars, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + build_offsets_from_lengths(str_lengths, str_offsets, stream); + int32_t total_chars = + thrust::reduce(rmm::exec_policy(stream), str_lengths.begin(), str_lengths.end(), 0); // Copy string data rmm::device_uvector chars(total_chars, stream, mr); @@ -2802,22 +3156,18 @@ std::unique_ptr build_repeated_struct_column( &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - // Build list 
offsets from counts entirely on GPU (performance fix!) - // Copy repeated_info to device and use thrust::transform to extract counts - rmm::device_uvector d_rep_info(num_rows, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_info.data(), - h_repeated_info.data(), - num_rows * sizeof(repeated_field_info), + // Build list offsets from per-row counts. + std::vector h_counts(num_rows); + for (int row = 0; row < num_rows; ++row) { + h_counts[row] = h_repeated_info[row].count; + } + rmm::device_uvector counts(num_rows, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(counts.data(), + h_counts.data(), + num_rows * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - rmm::device_uvector counts(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - d_rep_info.data(), - d_rep_info.end(), - counts.data(), - [] __device__(repeated_field_info const& info) { return info.count; }); - rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( rmm::exec_policy(stream), counts.data(), counts.end(), list_offs.begin(), 0); @@ -2862,8 +3212,8 @@ std::unique_ptr build_repeated_struct_column( // Scan for child fields within each message occurrence rmm::device_uvector d_child_locs(total_count * num_child_fields, stream, mr); - rmm::device_uvector d_error(1, stream, mr); - CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + // Reuse top-level error flag so failfast can observe nested repeated-message failures. + auto& d_error = d_error_top; auto const threads = THREADS_PER_BLOCK; auto const blocks = (total_count + threads - 1) / threads; @@ -2895,8 +3245,6 @@ std::unique_ptr build_repeated_struct_column( switch (dt.id()) { case cudf::type_id::BOOL8: { - rmm::device_uvector out(total_count, stream, mr); - rmm::device_uvector valid((total_count > 0 ? total_count : 1), stream, mr); int64_t def_val = has_def ? (default_bools[child_schema_idx] ? 
1 : 0) : 0; RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), 0, @@ -2904,139 +3252,71 @@ std::unique_ptr build_repeated_struct_column( d_child_locs.data(), ci, num_child_fields}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - total_count, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_val); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, total_count, out.release(), std::move(mask), null_count)); + struct_children.push_back(extract_and_build_scalar_column( + dt, + total_count, + [&](uint8_t* out_ptr, bool* valid_ptr) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + total_count, + out_ptr, + valid_ptr, + d_error.data(), + has_def, + def_val); + }, + stream, + mr)); break; } case cudf::type_id::INT32: { - rmm::device_uvector out(total_count, stream, mr); - rmm::device_uvector valid((total_count > 0 ? total_count : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; - if (enc == spark_rapids_jni::ENC_ZIGZAG) { - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - total_count, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } else if (enc == spark_rapids_jni::ENC_FIXED) { - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - total_count, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - static_cast(def_int)); - } else { - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - total_count, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, total_count, out.release(), std::move(mask), null_count)); + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; + struct_children.push_back(extract_and_build_integer_column(dt, + message_data, + loc_provider, + total_count, + blocks, + threads, + d_error, + has_def, + def_int, + enc, + true, + stream, + mr)); break; } case cudf::type_id::INT64: { - rmm::device_uvector out(total_count, stream, mr); - rmm::device_uvector valid((total_count > 0 ? total_count : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; - if (enc == spark_rapids_jni::ENC_ZIGZAG) { - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - total_count, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } else if (enc == spark_rapids_jni::ENC_FIXED) { - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - total_count, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } else { - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - total_count, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, total_count, out.release(), std::move(mask), null_count)); + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; + struct_children.push_back(extract_and_build_integer_column(dt, + message_data, + loc_provider, + total_count, + blocks, + threads, + d_error, + has_def, + def_int, + enc, + true, + stream, + mr)); break; } case cudf::type_id::FLOAT32: { - rmm::device_uvector out(total_count, stream, mr); - rmm::device_uvector valid((total_count > 0 ? total_count : 1), stream, mr); float def_float = has_def ? 
static_cast(default_floats[child_schema_idx]) : 0.0f; RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), 0, @@ -3044,23 +3324,25 @@ std::unique_ptr build_repeated_struct_column( d_child_locs.data(), ci, num_child_fields}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - total_count, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_float); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, total_count, out.release(), std::move(mask), null_count)); + struct_children.push_back(extract_and_build_scalar_column( + dt, + total_count, + [&](float* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + total_count, + out_ptr, + valid_ptr, + d_error.data(), + has_def, + def_float); + }, + stream, + mr)); break; } case cudf::type_id::FLOAT64: { - rmm::device_uvector out(total_count, stream, mr); - rmm::device_uvector valid((total_count > 0 ? total_count : 1), stream, mr); double def_double = has_def ? 
default_floats[child_schema_idx] : 0.0; RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), 0, @@ -3068,18 +3350,22 @@ std::unique_ptr build_repeated_struct_column( d_child_locs.data(), ci, num_child_fields}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - total_count, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_double); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back(std::make_unique( - dt, total_count, out.release(), std::move(mask), null_count)); + struct_children.push_back(extract_and_build_scalar_column( + dt, + total_count, + [&](double* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + total_count, + out_ptr, + valid_ptr, + d_error.data(), + has_def, + def_double); + }, + stream, + mr)); break; } case cudf::type_id::STRING: { @@ -3632,8 +3918,6 @@ std::unique_ptr build_nested_struct_column( switch (dt.id()) { case cudf::type_id::BOOL8: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_val = has_def ? (default_bools[child_schema_idx] ? 
1 : 0) : 0; NestedLocationProvider loc_provider{list_offsets, base_offset, @@ -3641,223 +3925,117 @@ std::unique_ptr build_nested_struct_column( d_child_locations.data(), ci, num_child_fields}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_val); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back( - std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); + struct_children.push_back(extract_and_build_scalar_column( + dt, + num_rows, + [&](uint8_t* out_ptr, bool* valid_ptr) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + d_error.data(), + has_def, + def_val); + }, + stream, + mr)); break; } case cudf::type_id::INT32: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; - if (enc == spark_rapids_jni::ENC_ZIGZAG) { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } else if (enc == spark_rapids_jni::ENC_FIXED) { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - static_cast(def_int)); - } else { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - 
(bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back( - std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + struct_children.push_back(extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_rows, + blocks, + threads, + d_error, + has_def, + def_int, + enc, + true, + stream, + mr)); break; } case cudf::type_id::UINT32: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; - if (enc == spark_rapids_jni::ENC_FIXED) { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - static_cast(def_int)); - } else { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back( - std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + struct_children.push_back(extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_rows, + blocks, + threads, + d_error, + has_def, + def_int, + enc, + false, + stream, + 
mr)); break; } case cudf::type_id::INT64: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; - if (enc == spark_rapids_jni::ENC_ZIGZAG) { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } else if (enc == spark_rapids_jni::ENC_FIXED) { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } else { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back( - std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + struct_children.push_back(extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_rows, + blocks, + threads, + d_error, + has_def, + def_int, + enc, + true, + stream, + mr)); break; } case cudf::type_id::UINT64: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; - if (enc == spark_rapids_jni::ENC_FIXED) { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - static_cast(def_int)); - } else { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back( - std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); + NestedLocationProvider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + struct_children.push_back(extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_rows, + blocks, + threads, + d_error, + has_def, + def_int, + enc, + false, + stream, + mr)); break; } case cudf::type_id::FLOAT32: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); float def_float = has_def ? 
static_cast(default_floats[child_schema_idx]) : 0.0f; NestedLocationProvider loc_provider{list_offsets, base_offset, @@ -3865,23 +4043,25 @@ std::unique_ptr build_nested_struct_column( d_child_locations.data(), ci, num_child_fields}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_float); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back( - std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); + struct_children.push_back(extract_and_build_scalar_column( + dt, + num_rows, + [&](float* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + d_error.data(), + has_def, + def_float); + }, + stream, + mr)); break; } case cudf::type_id::FLOAT64: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); double def_double = has_def ? 
default_floats[child_schema_idx] : 0.0; NestedLocationProvider loc_provider{list_offsets, base_offset, @@ -3889,18 +4069,22 @@ std::unique_ptr build_nested_struct_column( d_child_locations.data(), ci, num_child_fields}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_double); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - struct_children.push_back( - std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count)); + struct_children.push_back(extract_and_build_scalar_column( + dt, + num_rows, + [&](double* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + d_error.data(), + has_def, + def_double); + }, + stream, + mr)); break; } case cudf::type_id::STRING: { @@ -3919,7 +4103,7 @@ std::unique_ptr build_nested_struct_column( loc_provider, num_rows, out.data(), - (bool*)valid.data(), + valid.data(), d_error.data(), has_def, def_int); @@ -3929,110 +4113,14 @@ std::unique_ptr build_nested_struct_column( auto const& valid_enums = enum_valid_values[child_schema_idx]; auto const& enum_name_bytes = enum_names[child_schema_idx]; if (!valid_enums.empty() && valid_enums.size() == enum_name_bytes.size()) { - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - validate_enum_values_kernel<<>>( - out.data(), - valid.data(), - d_row_has_invalid_enum.data(), - d_valid_enums.data(), - static_cast(valid_enums.size()), - num_rows); - - std::vector h_name_offsets(valid_enums.size() + 1, 0); - int32_t total_name_chars = 0; - for (size_t k = 0; k < enum_name_bytes.size(); ++k) { - total_name_chars += static_cast(enum_name_bytes[k].size()); - h_name_offsets[k + 1] = total_name_chars; - } - std::vector 
h_name_chars(total_name_chars); - int32_t cursor = 0; - for (auto const& name : enum_name_bytes) { - if (!name.empty()) { - std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + cursor); - cursor += static_cast(name.size()); - } - } - - rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), - h_name_offsets.data(), - h_name_offsets.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - rmm::device_uvector d_name_chars(total_name_chars, stream, mr); - if (total_name_chars > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), - h_name_chars.data(), - total_name_chars * sizeof(uint8_t), - cudaMemcpyHostToDevice, - stream.value())); - } - - rmm::device_uvector lengths(num_rows, stream, mr); - compute_enum_string_lengths_kernel<<>>( - out.data(), - valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - static_cast(valid_enums.size()), - lengths.data(), - num_rows); - - rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - lengths.begin(), - lengths.end(), - output_offsets.begin(), - 0); - - int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, - output_offsets.data() + num_rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - lengths.data() + num_rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - total_chars += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, - &total_chars, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - copy_enum_string_chars_kernel<<>>( - out.data(), - valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - d_name_chars.data(), - static_cast(valid_enums.size()), - 
output_offsets.data(), - chars.data(), - num_rows); - } - - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - auto offsets_col = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - output_offsets.release(), - rmm::device_buffer{}, - 0); - struct_children.push_back(cudf::make_strings_column( - num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask))); + struct_children.push_back(build_enum_string_column(out, + valid, + valid_enums, + enum_name_bytes, + d_row_has_invalid_enum, + num_rows, + stream, + mr)); } else { CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value())); struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); @@ -4044,90 +4132,33 @@ std::unique_ptr build_nested_struct_column( } else { bool has_def_str = has_def; auto const& def_str = default_strings[child_schema_idx]; - int32_t def_len = has_def_str ? static_cast(def_str.size()) : 0; - - rmm::device_uvector d_default_str(def_len, stream, mr); - if (has_def_str && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), - def_str.data(), - def_len, - cudaMemcpyHostToDevice, - stream.value())); - } - - rmm::device_uvector lengths(num_rows, stream, mr); - NestedLocationProvider loc_provider{ + NestedLocationProvider len_provider{ nullptr, 0, d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields}; - extract_lengths_kernel<<>>( - loc_provider, num_rows, lengths.data(), has_def_str, def_len); - - rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); - - int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, - output_offsets.data() + num_rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - lengths.data() + num_rows - 1, - sizeof(int32_t), - 
cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - total_chars += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, - &total_chars, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - copy_varlen_data_kernel - <<>>(message_data, - loc_provider, - num_rows, - output_offsets.data(), - chars.data(), - d_error.data(), - has_def_str, - d_default_str.data(), - def_len); - } - - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - valid.data(), - [plocs = d_parent_locs.data(), - flocs = d_child_locations.data(), - ci, - num_child_fields, - has_def_str] __device__(auto row) { - return (plocs[row].offset >= 0 && flocs[row * num_child_fields + ci].offset >= 0) || - has_def_str; - }); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - output_offsets.release(), - rmm::device_buffer{}, - 0); - struct_children.push_back(cudf::make_strings_column( - num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask))); + NestedLocationProvider copy_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + auto valid_fn = [plocs = d_parent_locs.data(), + flocs = d_child_locations.data(), + ci, + num_child_fields, + has_def_str] __device__(cudf::size_type row) { + return (plocs[row].offset >= 0 && flocs[row * num_child_fields + ci].offset >= 0) || + has_def_str; + }; + struct_children.push_back(extract_and_build_string_or_bytes_column(false, + message_data, + 
num_rows, + len_provider, + copy_provider, + valid_fn, + has_def_str, + def_str, + d_error, + stream, + mr)); } break; } @@ -4135,101 +4166,33 @@ std::unique_ptr build_nested_struct_column( // bytes (BinaryType) represented as LIST bool has_def_bytes = has_def; auto const& def_bytes = default_strings[child_schema_idx]; - int32_t def_len = has_def_bytes ? static_cast(def_bytes.size()) : 0; - - rmm::device_uvector d_default_bytes(def_len, stream, mr); - if (has_def_bytes && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_bytes.data(), - def_bytes.data(), - def_len, - cudaMemcpyHostToDevice, - stream.value())); - } - - rmm::device_uvector lengths(num_rows, stream, mr); - NestedLocationProvider loc_provider{ + NestedLocationProvider len_provider{ nullptr, 0, d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields}; - extract_lengths_kernel<<>>( - loc_provider, num_rows, lengths.data(), has_def_bytes, def_len); - - rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); - - int32_t total_bytes = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, - output_offsets.data() + num_rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - lengths.data() + num_rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - total_bytes += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, - &total_bytes, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector bytes_data(total_bytes, stream, mr); - if (total_bytes > 0) { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - copy_varlen_data_kernel - <<>>(message_data, - loc_provider, - num_rows, - output_offsets.data(), - 
bytes_data.data(), - d_error.data(), - has_def_bytes, - d_default_bytes.data(), - def_len); - } - - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - valid.data(), - [plocs = d_parent_locs.data(), - flocs = d_child_locations.data(), - ci, - num_child_fields, - has_def_bytes] __device__(auto row) { - return (plocs[row].offset >= 0 && flocs[row * num_child_fields + ci].offset >= 0) || - has_def_bytes; - }); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - output_offsets.release(), - rmm::device_buffer{}, - 0); - auto bytes_child = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, - total_bytes, - rmm::device_buffer(bytes_data.data(), total_bytes, stream, mr), - rmm::device_buffer{}, - 0); - struct_children.push_back(cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(bytes_child), - null_count, - std::move(mask), - stream, - mr)); + NestedLocationProvider copy_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + auto valid_fn = [plocs = d_parent_locs.data(), + flocs = d_child_locations.data(), + ci, + num_child_fields, + has_def_bytes] __device__(cudf::size_type row) { + return (plocs[row].offset >= 0 && flocs[row * num_child_fields + ci].offset >= 0) || + has_def_bytes; + }; + struct_children.push_back(extract_and_build_string_or_bytes_column(true, + message_data, + num_rows, + len_provider, + copy_provider, + valid_fn, + has_def_bytes, + def_bytes, + d_error, + stream, + mr)); break; } case cudf::type_id::STRUCT: { @@ -4287,7 +4250,8 @@ std::unique_ptr build_nested_struct_column( } // anonymous namespace std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, - 
ProtobufDecodeContext const& context) + ProtobufDecodeContext const& context, + rmm::cuda_stream_view stream) { auto const& schema = context.schema; auto const& schema_output_types = context.schema_output_types; @@ -4305,10 +4269,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& CUDF_EXPECTS(child_type == cudf::type_id::INT8 || child_type == cudf::type_id::UINT8, "binary_input must be a LIST column"); - auto const stream = cudf::get_default_stream(); - auto mr = cudf::get_current_device_resource_ref(); - auto num_rows = binary_input.size(); - auto num_fields = static_cast(schema.size()); + auto mr = cudf::get_current_device_resource_ref(); + auto num_rows = binary_input.size(); + auto num_fields = static_cast(schema.size()); if (num_rows == 0 || num_fields == 0) { // Build empty struct based on top-level fields with proper nested structure @@ -4349,15 +4312,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Copy schema to device std::vector h_device_schema(num_fields); for (int i = 0; i < num_fields; i++) { - h_device_schema[i] = {schema[i].field_number, - schema[i].parent_idx, - schema[i].depth, - schema[i].wire_type, - static_cast(schema[i].output_type), - schema[i].encoding, - schema[i].is_repeated, - schema[i].is_required, - schema[i].has_default_value}; + h_device_schema[i] = device_nested_field_descriptor{schema[i]}; } rmm::device_uvector d_schema(num_fields, stream, mr); @@ -4393,6 +4348,15 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Error flag rmm::device_uvector d_error(1, stream, mr); CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + auto check_error_and_throw = [&]() { + if (!fail_on_errors) return; + int h_error = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync( + &h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + CUDF_EXPECTS(h_error == 0, + "Malformed protobuf message, unsupported wire type, or missing 
required field"); + }; // Enum validation support (PERMISSIVE mode) bool has_enum_fields = std::any_of( @@ -4443,6 +4407,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& num_nested, d_nested_indices.data(), d_error.data()); + check_error_and_throw(); } // For scalar fields at depth 0, use the existing scan_all_fields_kernel @@ -4470,6 +4435,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& scan_all_fields_kernel<<>>( *d_in, d_field_descs.data(), num_scalar, d_locations.data(), d_error.data()); + check_error_and_throw(); // Check required fields (after scan pass) { @@ -4494,6 +4460,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& stream.value())); check_required_fields_kernel<<>>( d_locations.data(), d_is_required.data(), num_scalar, num_rows, d_error.data()); + check_error_and_throw(); } } @@ -4516,66 +4483,46 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& switch (dt.id()) { case cudf::type_id::BOOL8: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_val = has_def ? (default_bools[schema_idx] ? 
1 : 0) : 0; TopLevelLocationProvider loc_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_val); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count); + column_map[schema_idx] = extract_and_build_scalar_column( + dt, + num_rows, + [&](uint8_t* out_ptr, bool* valid_ptr) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + d_error.data(), + has_def, + def_val); + }, + stream, + mr); break; } case cudf::type_id::INT32: { rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? default_ints[schema_idx] : 0; - if (enc == spark_rapids_jni::ENC_ZIGZAG) { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } else if (enc == spark_rapids_jni::ENC_FIXED) { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - static_cast(def_int)); - } else { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_integer_into_buffers(message_data, + loc_provider, 
+ num_rows, + blocks, + threads, + has_def, + def_int, + enc, + true, + out.data(), + valid.data(), + d_error.data(), + stream); // Enum validation: check if this INT32 field has valid enum values if (schema_idx < static_cast(enum_valid_values.size())) { auto const& valid_enums = enum_valid_values[schema_idx]; @@ -4601,157 +4548,104 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& break; } case cudf::type_id::UINT32: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? default_ints[schema_idx] : 0; - if (enc == spark_rapids_jni::ENC_FIXED) { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - static_cast(def_int)); - } else { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count); + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + column_map[schema_idx] = extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_rows, + blocks, + threads, + d_error, + has_def, + def_int, + enc, + false, + stream, + mr); break; } case cudf::type_id::INT64: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; - if (enc == spark_rapids_jni::ENC_ZIGZAG) { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } else if (enc == spark_rapids_jni::ENC_FIXED) { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } else { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count); + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + column_map[schema_idx] = extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_rows, + blocks, + threads, + d_error, + has_def, + def_int, + enc, + true, + stream, + mr); break; } case cudf::type_id::UINT64: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; - if (enc == spark_rapids_jni::ENC_FIXED) { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - static_cast(def_int)); - } else { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_int); - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count); + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + column_map[schema_idx] = extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_rows, + blocks, + threads, + d_error, + has_def, + def_int, + enc, + false, + stream, + mr); break; } case cudf::type_id::FLOAT32: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); float def_float = has_def ? 
static_cast(default_floats[schema_idx]) : 0.0f; TopLevelLocationProvider loc_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_float); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count); + column_map[schema_idx] = extract_and_build_scalar_column( + dt, + num_rows, + [&](float* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + d_error.data(), + has_def, + def_float); + }, + stream, + mr); break; } case cudf::type_id::FLOAT64: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); double def_double = has_def ? default_floats[schema_idx] : 0.0; TopLevelLocationProvider loc_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out.data(), - (bool*)valid.data(), - d_error.data(), - has_def, - def_double); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count); + column_map[schema_idx] = extract_and_build_scalar_column( + dt, + num_rows, + [&](double* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + d_error.data(), + has_def, + def_double); + }, + stream, + mr); break; } case cudf::type_id::STRING: { @@ -4770,7 +4664,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& loc_provider, num_rows, out.data(), - (bool*)valid.data(), + valid.data(), d_error.data(), has_def, def_int); @@ -4780,114 +4674,14 @@ std::unique_ptr 
decode_protobuf_to_struct(cudf::column_view const& auto const& valid_enums = enum_valid_values[schema_idx]; auto const& enum_name_bytes = enum_names[schema_idx]; if (!valid_enums.empty() && valid_enums.size() == enum_name_bytes.size()) { - // Validate enum numeric values first. - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - validate_enum_values_kernel<<>>( - out.data(), - valid.data(), - d_row_has_invalid_enum.data(), - d_valid_enums.data(), - static_cast(valid_enums.size()), - num_rows); - - // Build flattened enum-name chars and offsets on host, then copy to device. - std::vector h_name_offsets(valid_enums.size() + 1, 0); - int32_t total_name_chars = 0; - for (size_t k = 0; k < enum_name_bytes.size(); ++k) { - total_name_chars += static_cast(enum_name_bytes[k].size()); - h_name_offsets[k + 1] = total_name_chars; - } - std::vector h_name_chars(total_name_chars); - int32_t cursor = 0; - for (auto const& name : enum_name_bytes) { - if (!name.empty()) { - std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + cursor); - cursor += static_cast(name.size()); - } - } - - rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), - h_name_offsets.data(), - h_name_offsets.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - rmm::device_uvector d_name_chars(total_name_chars, stream, mr); - if (total_name_chars > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), - h_name_chars.data(), - total_name_chars * sizeof(uint8_t), - cudaMemcpyHostToDevice, - stream.value())); - } - - // Compute output UTF-8 lengths - rmm::device_uvector lengths(num_rows, stream, mr); - compute_enum_string_lengths_kernel<<>>( - out.data(), - valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - 
static_cast(valid_enums.size()), - lengths.data(), - num_rows); - - // Prefix sum for string offsets - rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - lengths.begin(), - lengths.end(), - output_offsets.begin(), - 0); - - int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, - output_offsets.data() + num_rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - lengths.data() + num_rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - total_chars += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, - &total_chars, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - copy_enum_string_chars_kernel<<>>( - out.data(), - valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - d_name_chars.data(), - static_cast(valid_enums.size()), - output_offsets.data(), - chars.data(), - num_rows); - } - - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - auto offsets_col = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - output_offsets.release(), - rmm::device_buffer{}, - 0); - column_map[schema_idx] = cudf::make_strings_column( - num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); + column_map[schema_idx] = build_enum_string_column(out, + valid, + valid_enums, + enum_name_bytes, + d_row_has_invalid_enum, + num_rows, + stream, + mr); } else { // Missing enum metadata for enum-as-string field; mark as decode error. 
CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value())); @@ -4901,85 +4695,24 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Regular protobuf STRING (length-delimited) bool has_def_str = has_def; auto const& def_str = default_strings[schema_idx]; - int32_t def_len = has_def_str ? static_cast(def_str.size()) : 0; - - rmm::device_uvector d_default_str(def_len, stream, mr); - if (has_def_str && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_str.data(), - def_str.data(), - def_len, - cudaMemcpyHostToDevice, - stream.value())); - } - - // Extract string lengths - rmm::device_uvector lengths(num_rows, stream, mr); - TopLevelLocationProvider loc_provider{nullptr, 0, d_locations.data(), i, num_scalar}; - extract_lengths_kernel - <<>>( - loc_provider, num_rows, lengths.data(), has_def_str, def_len); - - // Compute offsets via prefix sum - rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); - - int32_t total_chars = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, - output_offsets.data() + num_rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - lengths.data() + num_rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - total_chars += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, - &total_chars, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - // Copy string data - rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - copy_varlen_data_kernel - <<>>(message_data, - loc_provider, - num_rows, - output_offsets.data(), - chars.data(), - d_error.data(), - has_def_str, - 
d_default_str.data(), - def_len); - } - - // Build validity mask - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - valid.data(), - [locs = d_locations.data(), i, num_scalar, has_def_str] __device__(auto row) { - return locs[row * num_scalar + i].offset >= 0 || has_def_str; - }); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - output_offsets.release(), - rmm::device_buffer{}, - 0); - column_map[schema_idx] = cudf::make_strings_column( - num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); + TopLevelLocationProvider len_provider{nullptr, 0, d_locations.data(), i, num_scalar}; + TopLevelLocationProvider copy_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + auto valid_fn = [locs = d_locations.data(), i, num_scalar, has_def_str] __device__( + cudf::size_type row) { + return locs[row * num_scalar + i].offset >= 0 || has_def_str; + }; + column_map[schema_idx] = extract_and_build_string_or_bytes_column(false, + message_data, + num_rows, + len_provider, + copy_provider, + valid_fn, + has_def_str, + def_str, + d_error, + stream, + mr); } break; } @@ -4987,91 +4720,24 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // bytes (BinaryType) represented as LIST bool has_def_bytes = has_def; auto const& def_bytes = default_strings[schema_idx]; - int32_t def_len = has_def_bytes ? 
static_cast(def_bytes.size()) : 0; - - rmm::device_uvector d_default_bytes(def_len, stream, mr); - if (has_def_bytes && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_default_bytes.data(), - def_bytes.data(), - def_len, - cudaMemcpyHostToDevice, - stream.value())); - } - - rmm::device_uvector lengths(num_rows, stream, mr); - TopLevelLocationProvider loc_provider{nullptr, 0, d_locations.data(), i, num_scalar}; - extract_lengths_kernel<<>>( - loc_provider, num_rows, lengths.data(), has_def_bytes, def_len); - - rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), lengths.begin(), lengths.end(), output_offsets.begin(), 0); - - int32_t total_bytes = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_bytes, - output_offsets.data() + num_rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - lengths.data() + num_rows - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - total_bytes += last_len; - CUDF_CUDA_TRY(cudaMemcpyAsync(output_offsets.data() + num_rows, - &total_bytes, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector bytes_data(total_bytes, stream, mr); - if (total_bytes > 0) { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - copy_varlen_data_kernel - <<>>(message_data, - loc_provider, - num_rows, - output_offsets.data(), - bytes_data.data(), - d_error.data(), - has_def_bytes, - d_default_bytes.data(), - def_len); - } - - rmm::device_uvector valid((num_rows > 0 ? 
num_rows : 1), stream, mr); - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - valid.data(), - [locs = d_locations.data(), i, num_scalar, has_def_bytes] __device__(auto row) { - return locs[row * num_scalar + i].offset >= 0 || has_def_bytes; - }); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - output_offsets.release(), - rmm::device_buffer{}, - 0); - auto bytes_child = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, - total_bytes, - rmm::device_buffer(bytes_data.data(), total_bytes, stream, mr), - rmm::device_buffer{}, - 0); - column_map[schema_idx] = cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(bytes_child), - null_count, - std::move(mask), - stream, - mr); + TopLevelLocationProvider len_provider{nullptr, 0, d_locations.data(), i, num_scalar}; + TopLevelLocationProvider copy_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + auto valid_fn = [locs = d_locations.data(), i, num_scalar, has_def_bytes] __device__( + cudf::size_type row) { + return locs[row * num_scalar + i].offset >= 0 || has_def_bytes; + }; + column_map[schema_idx] = extract_and_build_string_or_bytes_column(true, + message_data, + num_rows, + len_provider, + copy_provider, + valid_fn, + has_def_bytes, + def_bytes, + d_error, + stream, + mr); break; } default: @@ -5171,6 +4837,28 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& stream, mr); break; + case cudf::type_id::UINT32: + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + h_device_schema[schema_idx], + field_info, + d_occurrences, + total_count, + num_rows, + stream, + mr); + break; + case cudf::type_id::UINT64: + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + h_device_schema[schema_idx], + field_info, + 
d_occurrences, + total_count, + num_rows, + stream, + mr); + break; case cudf::type_id::FLOAT32: column_map[schema_idx] = build_repeated_scalar_column(binary_input, @@ -5297,33 +4985,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // 5. Build string offsets rmm::device_uvector str_offsets(total_count + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - elem_lengths.data(), - elem_lengths.end(), - str_offsets.data(), - 0); - - int32_t total_chars = 0; - if (total_count > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(&total_chars, - str_offsets.data() + total_count - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - int32_t last_len = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(&last_len, - elem_lengths.data() + total_count - 1, - sizeof(int32_t), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - total_chars += last_len; - } - CUDF_CUDA_TRY(cudaMemcpyAsync(str_offsets.data() + total_count, - &total_chars, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + build_offsets_from_lengths(elem_lengths, str_offsets, stream); + int32_t total_chars = thrust::reduce( + rmm::exec_policy(stream), elem_lengths.begin(), elem_lengths.end(), 0); // 6. 
Copy string chars rmm::device_uvector chars(total_chars, stream, mr); diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index c95e16a2a5..a42ec6e80a 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -20,6 +20,8 @@ #include #include +#include + #include #include @@ -94,6 +96,7 @@ struct ProtobufDecodeContext { * @return STRUCT column with nested structure */ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, - ProtobufDecodeContext const& context); + ProtobufDecodeContext const& context, + rmm::cuda_stream_view stream); } // namespace spark_rapids_jni diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 51054c2f54..3ab1b8c988 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -1885,6 +1885,148 @@ void testDeepNestedMessageDepth3() { } } + @Test + void testPackedRepeatedInsideNestedMessage() { + // message Inner { repeated int32 ids = 1 [packed=true]; } + // message Outer { Inner inner = 1; } + byte[] packedIds = concatBytes(encodeVarint(10), encodeVarint(20), encodeVarint(30)); + Byte[] inner = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(packedIds.length)), + box(packedIds)); + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(inner.length)), + inner); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1, 1}, // outer.inner, inner.ids + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, true}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new 
double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + false)) { + assertEquals(DType.STRUCT, result.getType()); + try (ColumnVector innerStruct = result.getChildColumnView(0).copyToColumnVector(); + ColumnVector idsList = innerStruct.getChildColumnView(0).copyToColumnVector(); + ColumnVector ids = idsList.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostIds = ids.copyToHost()) { + assertEquals(3, ids.getRowCount()); + assertEquals(10, hostIds.getInt(0)); + assertEquals(20, hostIds.getInt(1)); + assertEquals(30, hostIds.getInt(2)); + } + } + } + + @Test + void testRepeatedUint32() { + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), + box(tag(1, WT_VARINT)), box(encodeVarint(2)), + box(tag(1, WT_VARINT)), box(encodeVarint(3))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.UINT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + try (ColumnVector list = result.getChildColumnView(0).copyToColumnVector(); + ColumnVector vals = list.getChildColumnView(0).copyToColumnVector()) { + assertEquals(DType.UINT32, vals.getType()); + assertEquals(3, vals.getRowCount()); + } + } + } + + @Test + void testRepeatedUint64() { + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(11)), + box(tag(1, WT_VARINT)), box(encodeVarint(22)), + box(tag(1, WT_VARINT)), box(encodeVarint(33))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new 
int[]{WT_VARINT}, + new int[]{DType.UINT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + try (ColumnVector list = result.getChildColumnView(0).copyToColumnVector(); + ColumnVector vals = list.getChildColumnView(0).copyToColumnVector()) { + assertEquals(DType.UINT64, vals.getType()); + assertEquals(3, vals.getRowCount()); + } + } + } + + @Test + void testWireTypeMismatchInRepeatedMessageChildFailfast() { + // message Item { int32 x = 1; } message Outer { repeated Item items = 1; } + // Encode x with WT_64BIT instead of WT_VARINT to force hard mismatch. + Byte[] badItem = concat(box(tag(1, WT_64BIT)), box(encodeFixed64(123L))); + Byte[] row = concat(box(tag(1, WT_LEN)), box(encodeVarint(badItem.length)), badItem); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector ignored = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, false}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + true)) { + } + }); + } + } + // ============================================================================ // FAILFAST Mode Tests (failOnErrors = true) // ============================================================================ From 34bcf0ba9cfcdb01b98010f12c06c439c3e18536 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 26 Feb 2026 16:47:50 +0800 Subject: [PATCH 028/107] 
ai self review and comment addressed Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 118 ++++++++---------- .../com/nvidia/spark/rapids/jni/Protobuf.java | 5 +- 2 files changed, 55 insertions(+), 68 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 37a95f1f8a..1e5d51f809 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -49,6 +49,8 @@ namespace { constexpr int WT_VARINT = 0; constexpr int WT_64BIT = 1; constexpr int WT_LEN = 2; +constexpr int WT_SGROUP = 3; +constexpr int WT_EGROUP = 4; constexpr int WT_32BIT = 5; // Protobuf varint encoding uses at most 10 bytes to represent a 64-bit value. @@ -163,8 +165,17 @@ __device__ inline bool read_varint(uint8_t const* cur, return false; } +__device__ inline void set_error_once(int* error_flag, int error_code) +{ + atomicCAS(error_flag, 0, error_code); +} + +// Keep call sites concise while ensuring first-error-wins semantics. +#define atomicExch(error_flag, error_code) set_error_once(error_flag, error_code) + __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) { + auto const* start = cur; switch (wt) { case WT_VARINT: { // Need to scan to find the end of varint @@ -191,6 +202,25 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con return -1; return n + static_cast(len); } + case WT_SGROUP: { + // Recursively skip until the matching end-group tag. 
+ while (cur < end) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, end, key, key_bytes)) return -1; + cur += key_bytes; + + int inner_wt = static_cast(key & 0x7); + if (inner_wt == WT_EGROUP) { return static_cast(cur - start); } + + int inner_size = get_wire_type_size(inner_wt, cur, end); + if (inner_size < 0 || cur + inner_size > end) return -1; + cur += inner_size; + } + return -1; + } + case WT_EGROUP: + return 0; default: return -1; } } @@ -200,6 +230,12 @@ __device__ inline bool skip_field(uint8_t const* cur, int wt, uint8_t const*& out_cur) { + // End-group is handled by the parent group parser. + if (wt == WT_EGROUP) { + out_cur = cur; + return true; + } + int size = get_wire_type_size(wt, cur, end); if (size < 0) return false; // Ensure we don't skip past the end of the buffer @@ -2732,7 +2768,7 @@ template std::unique_ptr build_repeated_scalar_column( cudf::column_view const& binary_input, device_nested_field_descriptor const& field_desc, - std::vector const& h_repeated_info, + rmm::device_uvector const& d_field_counts, rmm::device_uvector& d_occurrences, int total_count, int num_rows, @@ -2789,21 +2825,9 @@ std::unique_ptr build_repeated_scalar_column( &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - // Build list offsets from per-row counts. 
- std::vector h_counts(num_rows); - for (int row = 0; row < num_rows; ++row) { - h_counts[row] = h_repeated_info[row].count; - } - rmm::device_uvector counts(num_rows, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(counts.data(), - h_counts.data(), - num_rows * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), counts.data(), counts.end(), list_offs.begin(), 0); + rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); // Set last offset = total_count CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, @@ -2885,7 +2909,7 @@ std::unique_ptr build_repeated_scalar_column( std::unique_ptr build_repeated_string_column( cudf::column_view const& binary_input, device_nested_field_descriptor const& field_desc, - std::vector const& h_repeated_info, + rmm::device_uvector const& d_field_counts, rmm::device_uvector& d_occurrences, int total_count, int num_rows, @@ -2942,21 +2966,9 @@ std::unique_ptr build_repeated_string_column( &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - // Build list offsets from per-row counts. 
- std::vector h_counts(num_rows); - for (int row = 0; row < num_rows; ++row) { - h_counts[row] = h_repeated_info[row].count; - } - rmm::device_uvector counts(num_rows, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(counts.data(), - h_counts.data(), - num_rows * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), counts.data(), counts.end(), list_offs.begin(), 0); + rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); // Set last offset = total_count CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, @@ -3078,7 +3090,7 @@ std::unique_ptr build_nested_struct_column( std::unique_ptr build_repeated_struct_column( cudf::column_view const& binary_input, device_nested_field_descriptor const& field_desc, - std::vector const& h_repeated_info, + rmm::device_uvector const& d_field_counts, rmm::device_uvector& d_occurrences, int total_count, int num_rows, @@ -3156,21 +3168,9 @@ std::unique_ptr build_repeated_struct_column( &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - // Build list offsets from per-row counts. 
- std::vector h_counts(num_rows); - for (int row = 0; row < num_rows; ++row) { - h_counts[row] = h_repeated_info[row].count; - } - rmm::device_uvector counts(num_rows, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(counts.data(), - h_counts.data(), - num_rows * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), counts.data(), counts.end(), list_offs.begin(), 0); + rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); // Set last offset = total_count (already computed on caller side) CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, @@ -4750,14 +4750,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Process repeated fields if (num_repeated > 0) { - std::vector h_repeated_info(static_cast(num_rows) * num_repeated); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_repeated_info.data(), - d_repeated_info.data(), - h_repeated_info.size() * sizeof(repeated_field_info), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - cudf::lists_column_view const in_list_view(binary_input); auto const* list_offsets = in_list_view.offsets().data(); @@ -4776,12 +4768,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& int total_count = thrust::reduce(rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), 0); - // Still need host-side field_info for build_repeated_scalar_column - std::vector field_info(num_rows); - for (int row = 0; row < num_rows; row++) { - field_info[row] = h_repeated_info[row * num_repeated + ri]; - } - if (total_count > 0) { // Build offsets for occurrence scanning on GPU (performance fix!) 
rmm::device_uvector d_occ_offsets(num_rows + 1, stream, mr); @@ -4819,7 +4805,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& column_map[schema_idx] = build_repeated_scalar_column(binary_input, h_device_schema[schema_idx], - field_info, + d_field_counts, d_occurrences, total_count, num_rows, @@ -4830,7 +4816,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& column_map[schema_idx] = build_repeated_scalar_column(binary_input, h_device_schema[schema_idx], - field_info, + d_field_counts, d_occurrences, total_count, num_rows, @@ -4841,7 +4827,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& column_map[schema_idx] = build_repeated_scalar_column(binary_input, h_device_schema[schema_idx], - field_info, + d_field_counts, d_occurrences, total_count, num_rows, @@ -4852,7 +4838,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& column_map[schema_idx] = build_repeated_scalar_column(binary_input, h_device_schema[schema_idx], - field_info, + d_field_counts, d_occurrences, total_count, num_rows, @@ -4863,7 +4849,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& column_map[schema_idx] = build_repeated_scalar_column(binary_input, h_device_schema[schema_idx], - field_info, + d_field_counts, d_occurrences, total_count, num_rows, @@ -4874,7 +4860,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& column_map[schema_idx] = build_repeated_scalar_column(binary_input, h_device_schema[schema_idx], - field_info, + d_field_counts, d_occurrences, total_count, num_rows, @@ -4885,7 +4871,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& column_map[schema_idx] = build_repeated_scalar_column(binary_input, h_device_schema[schema_idx], - field_info, + d_field_counts, d_occurrences, total_count, num_rows, @@ -5056,7 +5042,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } else { column_map[schema_idx] = 
build_repeated_string_column(binary_input, h_device_schema[schema_idx], - field_info, + d_field_counts, d_occurrences, total_count, num_rows, @@ -5069,7 +5055,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case cudf::type_id::LIST: // bytes as LIST column_map[schema_idx] = build_repeated_string_column(binary_input, h_device_schema[schema_idx], - field_info, + d_field_counts, d_occurrences, total_count, num_rows, @@ -5086,7 +5072,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } else { column_map[schema_idx] = build_repeated_struct_column(binary_input, h_device_schema[schema_idx], - field_info, + d_field_counts, d_occurrences, total_count, num_rows, diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index e97a38f452..df807a0ba8 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -60,6 +60,7 @@ public class Protobuf { public static final int WT_64BIT = 1; public static final int WT_LEN = 2; public static final int WT_32BIT = 5; + private static final int MAX_FIELD_NUMBER = (1 << 29) - 1; /** * Decode protobuf messages into a STRUCT column using a flattened schema representation. 
@@ -136,10 +137,10 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, // Validate field numbers are positive and within protobuf spec range for (int i = 0; i < fieldNumbers.length; i++) { - if (fieldNumbers[i] <= 0 || fieldNumbers[i] > 536870911) { + if (fieldNumbers[i] <= 0 || fieldNumbers[i] > MAX_FIELD_NUMBER) { throw new IllegalArgumentException( "Invalid field number at index " + i + ": " + fieldNumbers[i] + - " (field numbers must be 1-536870911)"); + " (field numbers must be 1-" + MAX_FIELD_NUMBER + ")"); } } From 632448b9d505dc453c3c3c739ee31e5535d3f090 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 27 Feb 2026 10:37:01 +0800 Subject: [PATCH 029/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 123 +++++++++++++++++------------------ 1 file changed, 60 insertions(+), 63 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 1e5d51f809..e43223b67a 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -170,12 +170,8 @@ __device__ inline void set_error_once(int* error_flag, int error_code) atomicCAS(error_flag, 0, error_code); } -// Keep call sites concise while ensuring first-error-wins semantics. -#define atomicExch(error_flag, error_code) set_error_once(error_flag, error_code) - __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) { - auto const* start = cur; switch (wt) { case WT_VARINT: { // Need to scan to find the end of varint @@ -203,6 +199,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con return n + static_cast(len); } case WT_SGROUP: { + auto const* start = cur; // Recursively skip until the matching end-group tag. 
while (cur < end) { uint64_t key; @@ -278,7 +275,7 @@ __device__ inline bool check_message_bounds(int32_t start, int* error_flag) { if (start < 0 || end_pos < start || end_pos > total_size) { - atomicExch(error_flag, ERR_BOUNDS); + set_error_once(error_flag, ERR_BOUNDS); return false; } return true; @@ -297,7 +294,7 @@ __device__ inline bool decode_tag(uint8_t const*& cur, uint64_t key; int key_bytes; if (!read_varint(cur, end, key, key_bytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return false; } @@ -305,7 +302,7 @@ __device__ inline bool decode_tag(uint8_t const*& cur, tag.field_number = static_cast(key >> 3); tag.wire_type = static_cast(key & 0x7); if (tag.field_number == 0) { - atomicExch(error_flag, ERR_FIELD_NUMBER); + set_error_once(error_flag, ERR_FIELD_NUMBER); return false; } return true; @@ -392,7 +389,7 @@ __global__ void scan_all_fields_kernel( if (field_descs[f].field_number == fn) { // Check wire type matches if (wt != field_descs[f].expected_wire_type) { - atomicExch(error_flag, ERR_WIRE_TYPE); + set_error_once(error_flag, ERR_WIRE_TYPE); return; } @@ -404,12 +401,12 @@ __global__ void scan_all_fields_kernel( uint64_t len; int len_bytes; if (!read_varint(cur, msg_end, len, len_bytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } if (len > static_cast(msg_end - cur - len_bytes) || len > static_cast(INT_MAX)) { - atomicExch(error_flag, ERR_OVERFLOW); + set_error_once(error_flag, ERR_OVERFLOW); return; } // Record offset pointing to the actual data (after length prefix) @@ -418,7 +415,7 @@ __global__ void scan_all_fields_kernel( // For fixed-size and varint fields, record offset and compute length int field_size = get_wire_type_size(wt, cur, msg_end); if (field_size < 0) { - atomicExch(error_flag, ERR_FIELD_SIZE); + set_error_once(error_flag, ERR_FIELD_SIZE); return; } locations[row * num_fields + f] = {data_offset, field_size}; @@ -432,7 +429,7 @@ __global__ 
void scan_all_fields_kernel( // Skip to next field uint8_t const* next; if (!skip_field(cur, msg_end, wt, next)) { - atomicExch(error_flag, ERR_SKIP); + set_error_once(error_flag, ERR_SKIP); return; } cur = next; @@ -506,7 +503,7 @@ __global__ void count_repeated_fields_kernel( bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); if (!is_packed && wt != expected_wt) { - atomicExch(error_flag, ERR_WIRE_TYPE); + set_error_once(error_flag, ERR_WIRE_TYPE); return; } @@ -515,7 +512,7 @@ __global__ void count_repeated_fields_kernel( uint64_t packed_len; int len_bytes; if (!read_varint(cur, msg_end, packed_len, len_bytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } @@ -523,7 +520,7 @@ __global__ void count_repeated_fields_kernel( uint8_t const* packed_start = cur + len_bytes; uint8_t const* packed_end = packed_start + packed_len; if (packed_end > msg_end) { - atomicExch(error_flag, ERR_OVERFLOW); + set_error_once(error_flag, ERR_OVERFLOW); return; } @@ -535,7 +532,7 @@ __global__ void count_repeated_fields_kernel( uint64_t dummy; int vbytes; if (!read_varint(p, packed_end, dummy, vbytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } p += vbytes; @@ -554,7 +551,7 @@ __global__ void count_repeated_fields_kernel( // Non-packed encoding: single element int32_t data_offset, data_length; if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { - atomicExch(error_flag, ERR_FIELD_SIZE); + set_error_once(error_flag, ERR_FIELD_SIZE); return; } @@ -569,14 +566,14 @@ __global__ void count_repeated_fields_kernel( int schema_idx = nested_field_indices[i]; if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { if (wt != WT_LEN) { - atomicExch(error_flag, ERR_WIRE_TYPE); + set_error_once(error_flag, ERR_WIRE_TYPE); return; } uint64_t len; int len_bytes; if (!read_varint(cur, msg_end, len, len_bytes)) { - atomicExch(error_flag, 
ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } @@ -588,7 +585,7 @@ __global__ void count_repeated_fields_kernel( // Skip to next field uint8_t const* next; if (!skip_field(cur, msg_end, wt, next)) { - atomicExch(error_flag, ERR_SKIP); + set_error_once(error_flag, ERR_SKIP); return; } cur = next; @@ -645,14 +642,14 @@ __global__ void scan_repeated_field_occurrences_kernel( uint64_t packed_len; int len_bytes; if (!read_varint(cur, msg_end, packed_len, len_bytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } uint8_t const* packed_start = cur + len_bytes; uint8_t const* packed_end = packed_start + packed_len; if (packed_end > msg_end) { - atomicExch(error_flag, ERR_OVERFLOW); + set_error_once(error_flag, ERR_OVERFLOW); return; } @@ -665,7 +662,7 @@ __global__ void scan_repeated_field_occurrences_kernel( uint64_t dummy; int vbytes; if (!read_varint(p, packed_end, dummy, vbytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } occurrences[write_idx] = {static_cast(row), elem_offset, vbytes}; @@ -695,7 +692,7 @@ __global__ void scan_repeated_field_occurrences_kernel( // Non-packed encoding: single element int32_t data_offset, data_length; if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { - atomicExch(error_flag, ERR_FIELD_SIZE); + set_error_once(error_flag, ERR_FIELD_SIZE); return; } @@ -708,7 +705,7 @@ __global__ void scan_repeated_field_occurrences_kernel( // Skip to next field uint8_t const* next; if (!skip_field(cur, msg_end, wt, next)) { - atomicExch(error_flag, ERR_SKIP); + set_error_once(error_flag, ERR_SKIP); return; } cur = next; @@ -840,7 +837,7 @@ __global__ void extract_varint_kernel(uint8_t const* message_data, uint64_t v; int n; if (!read_varint(cur, cur_end, v, n)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); if (valid) valid[idx] = false; return; } @@ -881,7 +878,7 @@ __global__ void 
extract_fixed_kernel(uint8_t const* message_data, if constexpr (WT == WT_32BIT) { if (loc.length < 4) { - atomicExch(error_flag, ERR_FIXED_LEN); + set_error_once(error_flag, ERR_FIXED_LEN); if (valid) valid[idx] = false; return; } @@ -889,7 +886,7 @@ __global__ void extract_fixed_kernel(uint8_t const* message_data, memcpy(&value, &raw, sizeof(value)); } else { if (loc.length < 8) { - atomicExch(error_flag, ERR_FIXED_LEN); + set_error_once(error_flag, ERR_FIXED_LEN); if (valid) valid[idx] = false; return; } @@ -994,7 +991,7 @@ __global__ void extract_varint_from_locations_kernel( uint64_t v; int n; if (!read_varint(cur, cur_end, v, n)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); valid[row] = false; return; } @@ -1043,7 +1040,7 @@ __global__ void extract_fixed_from_locations_kernel(uint8_t const* message_data, OutputType value; if constexpr (WT == WT_32BIT) { if (loc.length < 4) { - atomicExch(error_flag, ERR_FIXED_LEN); + set_error_once(error_flag, ERR_FIXED_LEN); valid[row] = false; return; } @@ -1051,7 +1048,7 @@ __global__ void extract_fixed_from_locations_kernel(uint8_t const* message_data, memcpy(&value, &raw, sizeof(value)); } else { if (loc.length < 8) { - atomicExch(error_flag, ERR_FIXED_LEN); + set_error_once(error_flag, ERR_FIXED_LEN); valid[row] = false; return; } @@ -1090,7 +1087,7 @@ __global__ void extract_repeated_varint_kernel(uint8_t const* message_data, uint64_t v; int n; if (!read_varint(cur, cur_end, v, n)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); out[idx] = OutputType{}; return; } @@ -1121,7 +1118,7 @@ __global__ void extract_repeated_fixed_kernel(uint8_t const* message_data, OutputType value; if constexpr (WT == WT_32BIT) { if (occ.length < 4) { - atomicExch(error_flag, ERR_FIXED_LEN); + set_error_once(error_flag, ERR_FIXED_LEN); out[idx] = OutputType{}; return; } @@ -1129,7 +1126,7 @@ __global__ void extract_repeated_fixed_kernel(uint8_t const* message_data, 
memcpy(&value, &raw, sizeof(value)); } else { if (occ.length < 8) { - atomicExch(error_flag, ERR_FIXED_LEN); + set_error_once(error_flag, ERR_FIXED_LEN); out[idx] = OutputType{}; return; } @@ -1184,7 +1181,7 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, for (int f = 0; f < num_fields; f++) { if (field_descs[f].field_number == fn) { if (wt != field_descs[f].expected_wire_type) { - atomicExch(error_flag, ERR_WIRE_TYPE); + set_error_once(error_flag, ERR_WIRE_TYPE); return; } @@ -1194,12 +1191,12 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, uint64_t len; int len_bytes; if (!read_varint(cur, nested_end, len, len_bytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } if (len > static_cast(nested_end - cur - len_bytes) || len > static_cast(INT_MAX)) { - atomicExch(error_flag, ERR_OVERFLOW); + set_error_once(error_flag, ERR_OVERFLOW); return; } output_locations[row * num_fields + f] = {data_offset + len_bytes, @@ -1207,7 +1204,7 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, } else { int field_size = get_wire_type_size(wt, cur, nested_end); if (field_size < 0) { - atomicExch(error_flag, ERR_FIELD_SIZE); + set_error_once(error_flag, ERR_FIELD_SIZE); return; } output_locations[row * num_fields + f] = {data_offset, field_size}; @@ -1217,7 +1214,7 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, uint8_t const* next; if (!skip_field(cur, nested_end, wt, next)) { - atomicExch(error_flag, ERR_SKIP); + set_error_once(error_flag, ERR_SKIP); return; } cur = next; @@ -1413,7 +1410,7 @@ __global__ void scan_repeated_message_children_kernel( if (child_descs[f].field_number == fn) { bool is_packed = (wt == WT_LEN && child_descs[f].expected_wire_type != WT_LEN); if (!is_packed && wt != child_descs[f].expected_wire_type) { - atomicExch(error_flag, ERR_WIRE_TYPE); + set_error_once(error_flag, 
ERR_WIRE_TYPE); return; } @@ -1423,7 +1420,7 @@ __global__ void scan_repeated_message_children_kernel( uint64_t len; int len_bytes; if (!read_varint(cur, msg_end, len, len_bytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } // Store offset (after length prefix) and length @@ -1450,7 +1447,7 @@ __global__ void scan_repeated_message_children_kernel( // Skip to next field uint8_t const* next; if (!skip_field(cur, msg_end, wt, next)) { - atomicExch(error_flag, ERR_SKIP); + set_error_once(error_flag, ERR_SKIP); return; } cur = next; @@ -1505,7 +1502,7 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); if (!is_packed && wt != expected_wt) { - atomicExch(error_flag, ERR_WIRE_TYPE); + set_error_once(error_flag, ERR_WIRE_TYPE); return; } @@ -1513,13 +1510,13 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, uint64_t packed_len; int len_bytes; if (!read_varint(cur, msg_end, packed_len, len_bytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } uint8_t const* packed_start = cur + len_bytes; uint8_t const* packed_end = packed_start + packed_len; if (packed_end > msg_end) { - atomicExch(error_flag, ERR_OVERFLOW); + set_error_once(error_flag, ERR_OVERFLOW); return; } @@ -1530,7 +1527,7 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, uint64_t dummy; int vbytes; if (!read_varint(p, packed_end, dummy, vbytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } p += vbytes; @@ -1538,13 +1535,13 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, } } else if (expected_wt == WT_32BIT) { if ((packed_len % 4) != 0) { - atomicExch(error_flag, ERR_FIXED_LEN); + set_error_once(error_flag, ERR_FIXED_LEN); return; } count = static_cast(packed_len / 4); } else if (expected_wt == 
WT_64BIT) { if ((packed_len % 8) != 0) { - atomicExch(error_flag, ERR_FIXED_LEN); + set_error_once(error_flag, ERR_FIXED_LEN); return; } count = static_cast(packed_len / 8); @@ -1554,7 +1551,7 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, } else { int32_t data_offset, data_len; if (!get_field_data_location(cur, msg_end, wt, data_offset, data_len)) { - atomicExch(error_flag, ERR_FIELD_SIZE); + set_error_once(error_flag, ERR_FIELD_SIZE); return; } repeated_info[row * num_repeated + ri].count++; @@ -1565,7 +1562,7 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, uint8_t const* next; if (!skip_field(cur, msg_end, wt, next)) { - atomicExch(error_flag, ERR_SKIP); + set_error_once(error_flag, ERR_SKIP); return; } cur = next; @@ -1619,7 +1616,7 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); if (!is_packed && wt != expected_wt) { - atomicExch(error_flag, ERR_WIRE_TYPE); + set_error_once(error_flag, ERR_WIRE_TYPE); return; } @@ -1627,13 +1624,13 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, uint64_t packed_len; int len_bytes; if (!read_varint(cur, msg_end, packed_len, len_bytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } uint8_t const* packed_start = cur + len_bytes; uint8_t const* packed_end = packed_start + packed_len; if (packed_end > msg_end) { - atomicExch(error_flag, ERR_OVERFLOW); + set_error_once(error_flag, ERR_OVERFLOW); return; } @@ -1644,7 +1641,7 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, uint64_t dummy; int vbytes; if (!read_varint(p, packed_end, dummy, vbytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } occurrences[occ_offset + occ_idx] = {row, elem_offset, vbytes}; @@ -1653,7 +1650,7 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t 
const* message_data, } } else if (expected_wt == WT_32BIT) { if ((packed_len % 4) != 0) { - atomicExch(error_flag, ERR_FIXED_LEN); + set_error_once(error_flag, ERR_FIXED_LEN); return; } for (uint64_t i = 0; i < packed_len; i += 4) { @@ -1663,7 +1660,7 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, } } else if (expected_wt == WT_64BIT) { if ((packed_len % 8) != 0) { - atomicExch(error_flag, ERR_FIXED_LEN); + set_error_once(error_flag, ERR_FIXED_LEN); return; } for (uint64_t i = 0; i < packed_len; i += 8) { @@ -1679,7 +1676,7 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, uint64_t len; int len_bytes; if (!read_varint(cur, msg_end, len, len_bytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); return; } data_offset += len_bytes; @@ -1702,7 +1699,7 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, uint8_t const* next; if (!skip_field(cur, msg_end, wt, next)) { - atomicExch(error_flag, ERR_SKIP); + set_error_once(error_flag, ERR_SKIP); return; } cur = next; @@ -1752,7 +1749,7 @@ __global__ void extract_repeated_msg_child_varint_kernel(uint8_t const* message_ uint64_t val; int vbytes; if (!read_varint(cur, varint_end, val, vbytes)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); valid[idx] = false; return; } @@ -2151,7 +2148,7 @@ __global__ void extract_nested_varint_kernel(uint8_t const* message_data, uint64_t v; int n; if (!read_varint(cur, cur_end, v, n)) { - atomicExch(error_flag, ERR_VARINT); + set_error_once(error_flag, ERR_VARINT); valid[row] = false; return; } @@ -2201,7 +2198,7 @@ __global__ void extract_nested_fixed_kernel(uint8_t const* message_data, OutputType value; if constexpr (WT == WT_32BIT) { if (field_loc.length < 4) { - atomicExch(error_flag, ERR_FIXED_LEN); + set_error_once(error_flag, ERR_FIXED_LEN); valid[row] = false; return; } @@ -2209,7 +2206,7 @@ __global__ void 
extract_nested_fixed_kernel(uint8_t const* message_data, memcpy(&value, &raw, sizeof(value)); } else { if (field_loc.length < 8) { - atomicExch(error_flag, ERR_FIXED_LEN); + set_error_once(error_flag, ERR_FIXED_LEN); valid[row] = false; return; } @@ -2439,7 +2436,7 @@ __global__ void check_required_fields_kernel( for (int f = 0; f < num_fields; f++) { if (is_required[f] != 0 && locations[row * num_fields + f].offset < 0) { // Required field is missing - set error flag - atomicExch(error_flag, ERR_REQUIRED); + set_error_once(error_flag, ERR_REQUIRED); return; // No need to check other fields for this row } } From 46ef3b0cf9c40c54aae6598d2ab509736ff1122a Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sat, 28 Feb 2026 12:31:10 +0800 Subject: [PATCH 030/107] bug fixs Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 480 +++++++++++++++++++++++++++++++++-- 1 file changed, 463 insertions(+), 17 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index e43223b67a..2744a70ee3 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -2352,6 +2352,33 @@ std::unique_ptr make_empty_column_safe(cudf::data_type dtype, } } +/** + * Create an all-null LIST column with the provided child column. + * The child column is expected to have 0 rows. 
+ */ +std::unique_ptr make_null_list_column_with_child( + std::unique_ptr child_col, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + num_rows, + std::move(null_mask), + stream, + mr); +} + /** * Find all child field indices for a given parent index in the schema. * This is a commonly used pattern throughout the codebase. @@ -2399,13 +2426,32 @@ std::unique_ptr make_empty_struct_column_with_schema( for (int child_idx : child_indices) { auto child_type = schema_output_types[child_idx]; - // Recursively handle nested struct children + std::unique_ptr child_col; if (child_type.id() == cudf::type_id::STRUCT) { - children.push_back(make_empty_struct_column_with_schema( - schema, schema_output_types, child_idx, num_fields, stream, mr)); + child_col = make_empty_struct_column_with_schema( + schema, schema_output_types, child_idx, num_fields, stream, mr); } else { - children.push_back(make_empty_column_safe(child_type, stream, mr)); + child_col = make_empty_column_safe(child_type, stream, mr); + } + + if (schema[child_idx].is_repeated) { + auto offsets_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 1, + rmm::device_buffer(sizeof(int32_t), stream, mr), + rmm::device_buffer{}, + 0); + int32_t zero = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), + &zero, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + child_col = cudf::make_lists_column( + 0, std::move(offsets_col), std::move(child_col), 0, 
rmm::device_buffer{}, stream, mr); } + + children.push_back(std::move(child_col)); } return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer{}, stream, mr); @@ -3125,13 +3171,30 @@ std::unique_ptr build_repeated_struct_column( std::vector> empty_struct_children; for (int child_schema_idx : child_field_indices) { auto child_type = schema_output_types[child_schema_idx]; + std::unique_ptr child_col; if (child_type.id() == cudf::type_id::STRUCT) { - // Use helper to recursively build nested struct - empty_struct_children.push_back(make_empty_struct_column_with_schema( - h_device_schema, schema_output_types, child_schema_idx, num_schema_fields, stream, mr)); + child_col = make_empty_struct_column_with_schema( + h_device_schema, schema_output_types, child_schema_idx, num_schema_fields, stream, mr); } else { - empty_struct_children.push_back(make_empty_column_safe(child_type, stream, mr)); + child_col = make_empty_column_safe(child_type, stream, mr); + } + if (h_device_schema[child_schema_idx].is_repeated) { + auto offsets_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 1, + rmm::device_buffer(sizeof(int32_t), stream, mr), + rmm::device_buffer{}, + 0); + int32_t zero = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), + &zero, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + child_col = cudf::make_lists_column( + 0, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); } + empty_struct_children.push_back(std::move(child_col)); } auto empty_struct = cudf::make_structs_column( 0, std::move(empty_struct_children), 0, rmm::device_buffer{}, stream, mr); @@ -3195,6 +3258,7 @@ std::unique_ptr build_repeated_struct_column( // This replaces the host-side loop with D->H->D copy pattern (critical performance fix!) 
rmm::device_uvector d_msg_locs(total_count, stream, mr); rmm::device_uvector d_msg_row_offsets(total_count, stream, mr); + rmm::device_uvector d_msg_row_offsets_size(total_count, stream, mr); { auto const occ_threads = THREADS_PER_BLOCK; auto const occ_blocks = (total_count + occ_threads - 1) / occ_threads; @@ -3206,6 +3270,11 @@ std::unique_ptr build_repeated_struct_column( d_msg_row_offsets.data(), total_count); } + thrust::transform(rmm::exec_policy(stream), + d_msg_row_offsets.data(), + d_msg_row_offsets.end(), + d_msg_row_offsets_size.data(), + [] __device__(int32_t v) { return static_cast(v); }); // Scan for child fields within each message occurrence rmm::device_uvector d_child_locs(total_count * num_child_fields, stream, mr); @@ -3234,11 +3303,356 @@ std::unique_ptr build_repeated_struct_column( // Extract child field values - build one column per child field std::vector> struct_children; + int num_schema_fields = static_cast(h_device_schema.size()); for (int ci = 0; ci < num_child_fields; ci++) { int child_schema_idx = child_field_indices[ci]; auto const dt = schema_output_types[child_schema_idx]; auto const enc = h_device_schema[child_schema_idx].encoding; bool has_def = h_device_schema[child_schema_idx].has_default_value; + bool child_is_repeated = h_device_schema[child_schema_idx].is_repeated; + + if (child_is_repeated) { + auto const elem_type_id = schema[child_schema_idx].output_type; + rmm::device_uvector d_rep_info(total_count, stream, mr); + + std::vector rep_indices = {0}; + rmm::device_uvector d_rep_indices(1, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_indices.data(), + rep_indices.data(), + sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + + device_nested_field_descriptor rep_desc; + rep_desc.field_number = schema[child_schema_idx].field_number; + rep_desc.wire_type = schema[child_schema_idx].wire_type; + rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); + rep_desc.is_repeated = true; + 
rep_desc.parent_idx = -1; + rep_desc.depth = 0; + rep_desc.encoding = 0; + rep_desc.is_required = false; + rep_desc.has_default_value = false; + + std::vector h_rep_schema = {rep_desc}; + rmm::device_uvector d_rep_schema(1, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_schema.data(), + h_rep_schema.data(), + sizeof(device_nested_field_descriptor), + cudaMemcpyHostToDevice, + stream.value())); + + count_repeated_in_nested_kernel<<>>(message_data, + d_msg_row_offsets_size + .data(), + 0, + d_msg_locs.data(), + total_count, + d_rep_schema.data(), + 1, + d_rep_info.data(), + 1, + d_rep_indices.data(), + d_error.data()); + + rmm::device_uvector d_rep_counts(total_count, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_info.data(), + d_rep_info.end(), + d_rep_counts.data(), + [] __device__(repeated_field_info const& info) { return info.count; }); + int total_rep_count = + thrust::reduce(rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), 0); + + if (total_rep_count == 0) { + rmm::device_uvector list_offsets_vec(total_count + 1, stream, mr); + thrust::fill(rmm::exec_policy(stream), list_offsets_vec.data(), list_offsets_vec.end(), 0); + auto list_offsets_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_count + 1, + list_offsets_vec.release(), + rmm::device_buffer{}, + 0); + std::unique_ptr child_col; + if (elem_type_id == cudf::type_id::STRUCT) { + child_col = make_empty_struct_column_with_schema( + schema, schema_output_types, child_schema_idx, num_schema_fields, stream, mr); + } else { + child_col = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); + } + struct_children.push_back(cudf::make_lists_column(total_count, + std::move(list_offsets_col), + std::move(child_col), + 0, + rmm::device_buffer{}, + stream, + mr)); + } else { + rmm::device_uvector list_offs(total_count + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), + d_rep_counts.data(), + d_rep_counts.end(), + 
list_offs.begin(), + 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + total_count, + &total_rep_count, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); + scan_repeated_in_nested_kernel<<>>( + message_data, + d_msg_row_offsets_size.data(), + 0, + d_msg_locs.data(), + total_count, + d_rep_schema.data(), + 1, + list_offs.data(), + 1, + d_rep_indices.data(), + d_rep_occs.data(), + d_error.data()); + + std::unique_ptr child_values; + if (elem_type_id == cudf::type_id::INT32) { + rmm::device_uvector values(total_rep_count, stream, mr); + NestedRepeatedLocationProvider loc_provider{ + d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; + extract_varint_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); + child_values = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); + } else if (elem_type_id == cudf::type_id::INT64) { + rmm::device_uvector values(total_rep_count, stream, mr); + NestedRepeatedLocationProvider loc_provider{ + d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; + extract_varint_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); + child_values = std::make_unique(cudf::data_type{cudf::type_id::INT64}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); + } else if (elem_type_id == cudf::type_id::BOOL8) { + rmm::device_uvector values(total_rep_count, stream, mr); + NestedRepeatedLocationProvider loc_provider{ + d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; + extract_varint_kernel + <<<(total_rep_count + 
THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); + child_values = std::make_unique(cudf::data_type{cudf::type_id::BOOL8}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); + } else if (elem_type_id == cudf::type_id::FLOAT32) { + rmm::device_uvector values(total_rep_count, stream, mr); + NestedRepeatedLocationProvider loc_provider{ + d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; + extract_fixed_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); + child_values = std::make_unique(cudf::data_type{cudf::type_id::FLOAT32}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); + } else if (elem_type_id == cudf::type_id::FLOAT64) { + rmm::device_uvector values(total_rep_count, stream, mr); + NestedRepeatedLocationProvider loc_provider{ + d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; + extract_fixed_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); + child_values = std::make_unique(cudf::data_type{cudf::type_id::FLOAT64}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); + } else if (elem_type_id == cudf::type_id::STRING) { + rmm::device_uvector d_str_lengths(total_rep_count, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_occs.data(), + d_rep_occs.end(), + d_str_lengths.data(), + [] __device__(repeated_occurrence const& occ) { return occ.length; }); + + int32_t total_chars = + thrust::reduce(rmm::exec_policy(stream), d_str_lengths.data(), d_str_lengths.end(), 0); + rmm::device_uvector str_offs(total_rep_count 
+ 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), + d_str_lengths.data(), + d_str_lengths.end(), + str_offs.data(), + 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data() + total_rep_count, + &total_chars, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + NestedRepeatedLocationProvider loc_provider{ + d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; + copy_varlen_data_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + loc_provider, + total_rep_count, + str_offs.data(), + chars.data(), + d_error.data()); + } + + auto str_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_rep_count + 1, + str_offs.release(), + rmm::device_buffer{}, + 0); + child_values = cudf::make_strings_column( + total_rep_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); + } else if (elem_type_id == cudf::type_id::LIST) { + rmm::device_uvector d_len(total_rep_count, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_occs.data(), + d_rep_occs.end(), + d_len.data(), + [] __device__(repeated_occurrence const& occ) { return occ.length; }); + + int32_t total_bytes = + thrust::reduce(rmm::exec_policy(stream), d_len.data(), d_len.end(), 0); + rmm::device_uvector byte_offs(total_rep_count + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy(stream), d_len.data(), d_len.end(), byte_offs.data(), 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(byte_offs.data() + total_rep_count, + &total_bytes, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector bytes(total_bytes, stream, mr); + if (total_bytes > 0) { + NestedRepeatedLocationProvider loc_provider{ + d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; + copy_varlen_data_kernel + <<<(total_rep_count + THREADS_PER_BLOCK 
- 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + loc_provider, + total_rep_count, + byte_offs.data(), + bytes.data(), + d_error.data()); + } + + auto offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_rep_count + 1, + byte_offs.release(), + rmm::device_buffer{}, + 0); + auto bytes_child = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, + total_bytes, + rmm::device_buffer(bytes.data(), total_bytes, stream, mr), + rmm::device_buffer{}, + 0); + child_values = cudf::make_lists_column(total_rep_count, + std::move(offs_col), + std::move(bytes_child), + 0, + rmm::device_buffer{}, + stream, + mr); + } else if (elem_type_id == cudf::type_id::STRUCT) { + auto gc_indices = find_child_field_indices(schema, num_schema_fields, child_schema_idx); + if (gc_indices.empty()) { + child_values = cudf::make_structs_column(total_rep_count, + std::vector>{}, + 0, + rmm::device_buffer{}, + stream, + mr); + } else { + rmm::device_uvector d_virtual_row_offsets(total_rep_count, stream, mr); + rmm::device_uvector d_virtual_parent_locs(total_rep_count, stream, mr); + auto const rep_blk = (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + compute_virtual_parents_for_nested_repeated_kernel<<>>( + d_rep_occs.data(), + d_msg_row_offsets_size.data(), + d_msg_locs.data(), + d_virtual_row_offsets.data(), + d_virtual_parent_locs.data(), + total_rep_count); + + child_values = build_nested_struct_column(message_data, + d_virtual_row_offsets.data(), + 0, + d_virtual_parent_locs, + gc_indices, + schema, + num_schema_fields, + schema_output_types, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error_top, + total_rep_count, + stream, + mr, + 1); + } + } else { + child_values = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); + } + + auto list_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 
total_count + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + struct_children.push_back(cudf::make_lists_column(total_count, + std::move(list_offs_col), + std::move(child_values), + 0, + rmm::device_buffer{}, + stream, + mr)); + } + continue; + } switch (dt.id()) { case cudf::type_id::BOOL8: { @@ -3523,12 +3937,30 @@ std::unique_ptr build_nested_struct_column( std::vector> empty_children; for (int child_schema_idx : child_field_indices) { auto child_type = schema_output_types[child_schema_idx]; + std::unique_ptr child_col; if (child_type.id() == cudf::type_id::STRUCT) { - empty_children.push_back(make_empty_struct_column_with_schema( - schema, schema_output_types, child_schema_idx, num_fields, stream, mr)); + child_col = make_empty_struct_column_with_schema( + schema, schema_output_types, child_schema_idx, num_fields, stream, mr); } else { - empty_children.push_back(make_empty_column_safe(child_type, stream, mr)); + child_col = make_empty_column_safe(child_type, stream, mr); } + if (schema[child_schema_idx].is_repeated) { + auto offsets_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 1, + rmm::device_buffer(sizeof(int32_t), stream, mr), + rmm::device_buffer{}, + 0); + int32_t zero = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), + &zero, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + child_col = cudf::make_lists_column( + 0, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + } + empty_children.push_back(std::move(child_col)); } return cudf::make_structs_column( 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); @@ -5064,8 +5496,10 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Repeated message field - ArrayType(StructType) auto child_field_indices = find_child_field_indices(schema, num_fields, schema_idx); if (child_field_indices.empty()) { - // No child fields - create null column - column_map[schema_idx] = 
make_null_column(element_type, num_rows, stream, mr); + auto empty_struct_child = make_empty_struct_column_with_schema( + schema, schema_output_types, schema_idx, num_fields, stream, mr); + column_map[schema_idx] = make_null_list_column_with_child( + std::move(empty_struct_child), num_rows, stream, mr); } else { column_map[schema_idx] = build_repeated_struct_column(binary_input, h_device_schema[schema_idx], @@ -5092,7 +5526,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } default: // Unsupported element type - create null column - column_map[schema_idx] = make_null_column(element_type, num_rows, stream, mr); + column_map[schema_idx] = make_null_list_column_with_child( + make_empty_column_safe(element_type, stream, mr), num_rows, stream, mr); break; } } else { @@ -5215,9 +5650,20 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& if (it != column_map.end()) { top_level_children.push_back(std::move(it->second)); } else { - // Field not processed - create null column - top_level_children.push_back( - make_null_column(schema_output_types[i], num_rows, stream, mr)); + if (schema[i].is_repeated) { + auto const element_type = schema_output_types[i]; + std::unique_ptr empty_child; + if (element_type.id() == cudf::type_id::STRUCT) { + empty_child = make_empty_struct_column_with_schema( + schema, schema_output_types, i, num_fields, stream, mr); + } else { + empty_child = make_empty_column_safe(element_type, stream, mr); + } + top_level_children.push_back(make_null_list_column_with_child( + std::move(empty_child), num_rows, stream, mr)); + } else { + top_level_children.push_back(make_null_column(schema_output_types[i], num_rows, stream, mr)); + } } } } From 5438a17a0133f855248a3a7c99120c430edb245b Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sat, 28 Feb 2026 13:29:00 +0800 Subject: [PATCH 031/107] clean up code Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 1145 +++++++++++++--------------------- 1 file changed, 428 
insertions(+), 717 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 2744a70ee3..85772cfacc 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -2379,6 +2379,30 @@ std::unique_ptr make_null_list_column_with_child( mr); } +/** + * Wrap a 0-row element column into a 0-row LIST column. + */ +std::unique_ptr make_empty_list_column( + std::unique_ptr element_col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto offsets_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 1, + rmm::device_buffer(sizeof(int32_t), stream, mr), + rmm::device_buffer{}, + 0); + int32_t zero = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), + &zero, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + return cudf::make_lists_column( + 0, std::move(offsets_col), std::move(element_col), 0, rmm::device_buffer{}, stream, mr); +} + /** * Find all child field indices for a given parent index in the schema. * This is a commonly used pattern throughout the codebase. 
@@ -2435,20 +2459,7 @@ std::unique_ptr make_empty_struct_column_with_schema( } if (schema[child_idx].is_repeated) { - auto offsets_col = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - 1, - rmm::device_buffer(sizeof(int32_t), stream, mr), - rmm::device_buffer{}, - 0); - int32_t zero = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), - &zero, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - child_col = cudf::make_lists_column( - 0, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + child_col = make_empty_list_column(std::move(child_col), stream, mr); } children.push_back(std::move(child_col)); @@ -3125,6 +3136,30 @@ std::unique_ptr build_nested_struct_column( rmm::device_async_resource_ref mr, int depth); +// Forward declaration -- build_repeated_child_list_column is defined after build_nested_struct_column +// but both build_repeated_struct_column and build_nested_struct_column need to call it. +std::unique_ptr build_repeated_child_list_column( + uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_parent_rows, + int child_schema_idx, + std::vector const& schema, + int num_fields, + std::vector const& schema_output_types, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int depth); + /** * Build a repeated struct column (LIST of STRUCT). 
* This handles repeated message fields like: repeated Item items = 2; @@ -3179,20 +3214,7 @@ std::unique_ptr build_repeated_struct_column( child_col = make_empty_column_safe(child_type, stream, mr); } if (h_device_schema[child_schema_idx].is_repeated) { - auto offsets_col = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - 1, - rmm::device_buffer(sizeof(int32_t), stream, mr), - rmm::device_buffer{}, - 0); - int32_t zero = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), - &zero, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - child_col = cudf::make_lists_column( - 0, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + child_col = make_empty_list_column(std::move(child_col), stream, mr); } empty_struct_children.push_back(std::move(child_col)); } @@ -3312,345 +3334,12 @@ std::unique_ptr build_repeated_struct_column( bool child_is_repeated = h_device_schema[child_schema_idx].is_repeated; if (child_is_repeated) { - auto const elem_type_id = schema[child_schema_idx].output_type; - rmm::device_uvector d_rep_info(total_count, stream, mr); - - std::vector rep_indices = {0}; - rmm::device_uvector d_rep_indices(1, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_indices.data(), - rep_indices.data(), - sizeof(int), - cudaMemcpyHostToDevice, - stream.value())); - - device_nested_field_descriptor rep_desc; - rep_desc.field_number = schema[child_schema_idx].field_number; - rep_desc.wire_type = schema[child_schema_idx].wire_type; - rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); - rep_desc.is_repeated = true; - rep_desc.parent_idx = -1; - rep_desc.depth = 0; - rep_desc.encoding = 0; - rep_desc.is_required = false; - rep_desc.has_default_value = false; - - std::vector h_rep_schema = {rep_desc}; - rmm::device_uvector d_rep_schema(1, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_schema.data(), - h_rep_schema.data(), - sizeof(device_nested_field_descriptor), - 
cudaMemcpyHostToDevice, - stream.value())); - - count_repeated_in_nested_kernel<<>>(message_data, - d_msg_row_offsets_size - .data(), - 0, - d_msg_locs.data(), - total_count, - d_rep_schema.data(), - 1, - d_rep_info.data(), - 1, - d_rep_indices.data(), - d_error.data()); - - rmm::device_uvector d_rep_counts(total_count, stream, mr); - thrust::transform(rmm::exec_policy(stream), - d_rep_info.data(), - d_rep_info.end(), - d_rep_counts.data(), - [] __device__(repeated_field_info const& info) { return info.count; }); - int total_rep_count = - thrust::reduce(rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), 0); - - if (total_rep_count == 0) { - rmm::device_uvector list_offsets_vec(total_count + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), list_offsets_vec.data(), list_offsets_vec.end(), 0); - auto list_offsets_col = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_count + 1, - list_offsets_vec.release(), - rmm::device_buffer{}, - 0); - std::unique_ptr child_col; - if (elem_type_id == cudf::type_id::STRUCT) { - child_col = make_empty_struct_column_with_schema( - schema, schema_output_types, child_schema_idx, num_schema_fields, stream, mr); - } else { - child_col = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); - } - struct_children.push_back(cudf::make_lists_column(total_count, - std::move(list_offsets_col), - std::move(child_col), - 0, - rmm::device_buffer{}, - stream, - mr)); - } else { - rmm::device_uvector list_offs(total_count + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - d_rep_counts.data(), - d_rep_counts.end(), - list_offs.begin(), - 0); - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + total_count, - &total_rep_count, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); - scan_repeated_in_nested_kernel<<>>( - message_data, - d_msg_row_offsets_size.data(), - 0, - d_msg_locs.data(), - total_count, 
- d_rep_schema.data(), - 1, - list_offs.data(), - 1, - d_rep_indices.data(), - d_rep_occs.data(), - d_error.data()); - - std::unique_ptr child_values; - if (elem_type_id == cudf::type_id::INT32) { - rmm::device_uvector values(total_rep_count, stream, mr); - NestedRepeatedLocationProvider loc_provider{ - d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; - extract_varint_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>( - message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); - child_values = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_rep_count, - values.release(), - rmm::device_buffer{}, - 0); - } else if (elem_type_id == cudf::type_id::INT64) { - rmm::device_uvector values(total_rep_count, stream, mr); - NestedRepeatedLocationProvider loc_provider{ - d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; - extract_varint_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>( - message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); - child_values = std::make_unique(cudf::data_type{cudf::type_id::INT64}, - total_rep_count, - values.release(), - rmm::device_buffer{}, - 0); - } else if (elem_type_id == cudf::type_id::BOOL8) { - rmm::device_uvector values(total_rep_count, stream, mr); - NestedRepeatedLocationProvider loc_provider{ - d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; - extract_varint_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>( - message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); - child_values = std::make_unique(cudf::data_type{cudf::type_id::BOOL8}, - total_rep_count, - values.release(), - rmm::device_buffer{}, - 0); - } else if (elem_type_id == 
cudf::type_id::FLOAT32) { - rmm::device_uvector values(total_rep_count, stream, mr); - NestedRepeatedLocationProvider loc_provider{ - d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; - extract_fixed_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>( - message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); - child_values = std::make_unique(cudf::data_type{cudf::type_id::FLOAT32}, - total_rep_count, - values.release(), - rmm::device_buffer{}, - 0); - } else if (elem_type_id == cudf::type_id::FLOAT64) { - rmm::device_uvector values(total_rep_count, stream, mr); - NestedRepeatedLocationProvider loc_provider{ - d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; - extract_fixed_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>( - message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); - child_values = std::make_unique(cudf::data_type{cudf::type_id::FLOAT64}, - total_rep_count, - values.release(), - rmm::device_buffer{}, - 0); - } else if (elem_type_id == cudf::type_id::STRING) { - rmm::device_uvector d_str_lengths(total_rep_count, stream, mr); - thrust::transform(rmm::exec_policy(stream), - d_rep_occs.data(), - d_rep_occs.end(), - d_str_lengths.data(), - [] __device__(repeated_occurrence const& occ) { return occ.length; }); - - int32_t total_chars = - thrust::reduce(rmm::exec_policy(stream), d_str_lengths.data(), d_str_lengths.end(), 0); - rmm::device_uvector str_offs(total_rep_count + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - d_str_lengths.data(), - d_str_lengths.end(), - str_offs.data(), - 0); - CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data() + total_rep_count, - &total_chars, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector chars(total_chars, stream, mr); - if 
(total_chars > 0) { - NestedRepeatedLocationProvider loc_provider{ - d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; - copy_varlen_data_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>(message_data, - loc_provider, - total_rep_count, - str_offs.data(), - chars.data(), - d_error.data()); - } - - auto str_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_rep_count + 1, - str_offs.release(), - rmm::device_buffer{}, - 0); - child_values = cudf::make_strings_column( - total_rep_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); - } else if (elem_type_id == cudf::type_id::LIST) { - rmm::device_uvector d_len(total_rep_count, stream, mr); - thrust::transform(rmm::exec_policy(stream), - d_rep_occs.data(), - d_rep_occs.end(), - d_len.data(), - [] __device__(repeated_occurrence const& occ) { return occ.length; }); - - int32_t total_bytes = - thrust::reduce(rmm::exec_policy(stream), d_len.data(), d_len.end(), 0); - rmm::device_uvector byte_offs(total_rep_count + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), d_len.data(), d_len.end(), byte_offs.data(), 0); - CUDF_CUDA_TRY(cudaMemcpyAsync(byte_offs.data() + total_rep_count, - &total_bytes, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector bytes(total_bytes, stream, mr); - if (total_bytes > 0) { - NestedRepeatedLocationProvider loc_provider{ - d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), d_rep_occs.data()}; - copy_varlen_data_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>(message_data, - loc_provider, - total_rep_count, - byte_offs.data(), - bytes.data(), - d_error.data()); - } - - auto offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_rep_count + 1, - byte_offs.release(), - rmm::device_buffer{}, - 0); - auto 
bytes_child = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, - total_bytes, - rmm::device_buffer(bytes.data(), total_bytes, stream, mr), - rmm::device_buffer{}, - 0); - child_values = cudf::make_lists_column(total_rep_count, - std::move(offs_col), - std::move(bytes_child), - 0, - rmm::device_buffer{}, - stream, - mr); - } else if (elem_type_id == cudf::type_id::STRUCT) { - auto gc_indices = find_child_field_indices(schema, num_schema_fields, child_schema_idx); - if (gc_indices.empty()) { - child_values = cudf::make_structs_column(total_rep_count, - std::vector>{}, - 0, - rmm::device_buffer{}, - stream, - mr); - } else { - rmm::device_uvector d_virtual_row_offsets(total_rep_count, stream, mr); - rmm::device_uvector d_virtual_parent_locs(total_rep_count, stream, mr); - auto const rep_blk = (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; - compute_virtual_parents_for_nested_repeated_kernel<<>>( - d_rep_occs.data(), - d_msg_row_offsets_size.data(), - d_msg_locs.data(), - d_virtual_row_offsets.data(), - d_virtual_parent_locs.data(), - total_rep_count); - - child_values = build_nested_struct_column(message_data, - d_virtual_row_offsets.data(), - 0, - d_virtual_parent_locs, - gc_indices, - schema, - num_schema_fields, - schema_output_types, - default_ints, - default_floats, - default_bools, - default_strings, - enum_valid_values, - enum_names, - d_row_has_invalid_enum, - d_error_top, - total_rep_count, - stream, - mr, - 1); - } - } else { - child_values = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); - } - - auto list_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_count + 1, - list_offs.release(), - rmm::device_buffer{}, - 0); - struct_children.push_back(cudf::make_lists_column(total_count, - std::move(list_offs_col), - std::move(child_values), - 0, - rmm::device_buffer{}, - stream, - mr)); - } + struct_children.push_back(build_repeated_child_list_column( + message_data, 
d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), total_count, + child_schema_idx, schema, num_schema_fields, schema_output_types, + default_ints, default_floats, default_bools, default_strings, + enum_valid_values, enum_names, d_row_has_invalid_enum, d_error_top, + stream, mr, 1)); continue; } @@ -3945,20 +3634,7 @@ std::unique_ptr build_nested_struct_column( child_col = make_empty_column_safe(child_type, stream, mr); } if (schema[child_schema_idx].is_repeated) { - auto offsets_col = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - 1, - rmm::device_buffer(sizeof(int32_t), stream, mr), - rmm::device_buffer{}, - 0); - int32_t zero = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), - &zero, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - child_col = cudf::make_lists_column( - 0, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + child_col = make_empty_list_column(std::move(child_col), stream, mr); } empty_children.push_back(std::move(child_col)); } @@ -4006,342 +3682,12 @@ std::unique_ptr build_nested_struct_column( bool is_repeated = schema[child_schema_idx].is_repeated; if (is_repeated) { - auto elem_type_id = schema[child_schema_idx].output_type; - rmm::device_uvector d_rep_info(num_rows, stream, mr); - - std::vector rep_indices = {0}; - rmm::device_uvector d_rep_indices(1, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_indices.data(), - rep_indices.data(), - sizeof(int), - cudaMemcpyHostToDevice, - stream.value())); - - device_nested_field_descriptor rep_desc; - rep_desc.field_number = schema[child_schema_idx].field_number; - rep_desc.wire_type = schema[child_schema_idx].wire_type; - rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); - rep_desc.is_repeated = true; - rep_desc.parent_idx = -1; - rep_desc.depth = 0; - rep_desc.encoding = 0; - rep_desc.is_required = false; - rep_desc.has_default_value = false; - - std::vector h_rep_schema = 
{rep_desc}; - rmm::device_uvector d_rep_schema(1, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_schema.data(), - h_rep_schema.data(), - sizeof(device_nested_field_descriptor), - cudaMemcpyHostToDevice, - stream.value())); - - count_repeated_in_nested_kernel<<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - num_rows, - d_rep_schema.data(), - 1, - d_rep_info.data(), - 1, - d_rep_indices.data(), - d_error.data()); - - rmm::device_uvector d_rep_counts(num_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), - d_rep_info.data(), - d_rep_info.end(), - d_rep_counts.data(), - [] __device__(repeated_field_info const& info) { return info.count; }); - int total_rep_count = - thrust::reduce(rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), 0); - - if (total_rep_count == 0) { - rmm::device_uvector list_offsets_vec(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), list_offsets_vec.data(), list_offsets_vec.end(), 0); - auto list_offsets_col = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - list_offsets_vec.release(), - rmm::device_buffer{}, - 0); - std::unique_ptr child_col; - if (elem_type_id == cudf::type_id::STRUCT) { - child_col = make_empty_struct_column_with_schema( - schema, schema_output_types, child_schema_idx, num_fields, stream, mr); - } else { - child_col = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); - } - struct_children.push_back(cudf::make_lists_column(num_rows, - std::move(list_offsets_col), - std::move(child_col), - 0, - rmm::device_buffer{}, - stream, - mr)); - } else { - rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), list_offs.begin(), 0); - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, - &total_rep_count, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector 
d_rep_occs(total_rep_count, stream, mr); - scan_repeated_in_nested_kernel<<>>(message_data, - list_offsets, - base_offset, - d_parent_locs.data(), - num_rows, - d_rep_schema.data(), - 1, - list_offs.data(), - 1, - d_rep_indices.data(), - d_rep_occs.data(), - d_error.data()); - - std::unique_ptr child_values; - if (elem_type_id == cudf::type_id::INT32) { - rmm::device_uvector values(total_rep_count, stream, mr); - NestedRepeatedLocationProvider loc_provider{ - list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; - extract_varint_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>( - message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); - child_values = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_rep_count, - values.release(), - rmm::device_buffer{}, - 0); - } else if (elem_type_id == cudf::type_id::INT64) { - rmm::device_uvector values(total_rep_count, stream, mr); - NestedRepeatedLocationProvider loc_provider{ - list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; - extract_varint_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>( - message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); - child_values = std::make_unique(cudf::data_type{cudf::type_id::INT64}, - total_rep_count, - values.release(), - rmm::device_buffer{}, - 0); - } else if (elem_type_id == cudf::type_id::BOOL8) { - rmm::device_uvector values(total_rep_count, stream, mr); - NestedRepeatedLocationProvider loc_provider{ - list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; - extract_varint_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>( - message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); - child_values = 
std::make_unique(cudf::data_type{cudf::type_id::BOOL8}, - total_rep_count, - values.release(), - rmm::device_buffer{}, - 0); - } else if (elem_type_id == cudf::type_id::FLOAT32) { - rmm::device_uvector values(total_rep_count, stream, mr); - NestedRepeatedLocationProvider loc_provider{ - list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; - extract_fixed_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>( - message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); - child_values = std::make_unique(cudf::data_type{cudf::type_id::FLOAT32}, - total_rep_count, - values.release(), - rmm::device_buffer{}, - 0); - } else if (elem_type_id == cudf::type_id::FLOAT64) { - rmm::device_uvector values(total_rep_count, stream, mr); - NestedRepeatedLocationProvider loc_provider{ - list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; - extract_fixed_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>( - message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); - child_values = std::make_unique(cudf::data_type{cudf::type_id::FLOAT64}, - total_rep_count, - values.release(), - rmm::device_buffer{}, - 0); - } else if (elem_type_id == cudf::type_id::STRING) { - rmm::device_uvector d_str_lengths(total_rep_count, stream, mr); - thrust::transform(rmm::exec_policy(stream), - d_rep_occs.data(), - d_rep_occs.end(), - d_str_lengths.data(), - [] __device__(repeated_occurrence const& occ) { return occ.length; }); - - int32_t total_chars = - thrust::reduce(rmm::exec_policy(stream), d_str_lengths.data(), d_str_lengths.end(), 0); - rmm::device_uvector str_offs(total_rep_count + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - d_str_lengths.data(), - d_str_lengths.end(), - str_offs.data(), - 0); - CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data() + 
total_rep_count, - &total_chars, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - NestedRepeatedLocationProvider loc_provider{ - list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; - copy_varlen_data_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>(message_data, - loc_provider, - total_rep_count, - str_offs.data(), - chars.data(), - d_error.data()); - } - - auto str_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_rep_count + 1, - str_offs.release(), - rmm::device_buffer{}, - 0); - child_values = cudf::make_strings_column( - total_rep_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); - } else if (elem_type_id == cudf::type_id::LIST) { - rmm::device_uvector d_len(total_rep_count, stream, mr); - thrust::transform(rmm::exec_policy(stream), - d_rep_occs.data(), - d_rep_occs.end(), - d_len.data(), - [] __device__(repeated_occurrence const& occ) { return occ.length; }); - - int32_t total_bytes = - thrust::reduce(rmm::exec_policy(stream), d_len.data(), d_len.end(), 0); - rmm::device_uvector byte_offs(total_rep_count + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), d_len.data(), d_len.end(), byte_offs.data(), 0); - CUDF_CUDA_TRY(cudaMemcpyAsync(byte_offs.data() + total_rep_count, - &total_bytes, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector bytes(total_bytes, stream, mr); - if (total_bytes > 0) { - NestedRepeatedLocationProvider loc_provider{ - list_offsets, base_offset, d_parent_locs.data(), d_rep_occs.data()}; - copy_varlen_data_kernel - <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, - THREADS_PER_BLOCK, - 0, - stream.value()>>>(message_data, - loc_provider, - total_rep_count, - byte_offs.data(), - bytes.data(), - d_error.data()); - } - - auto offs_col 
= std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_rep_count + 1, - byte_offs.release(), - rmm::device_buffer{}, - 0); - auto bytes_child = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, - total_bytes, - rmm::device_buffer(bytes.data(), total_bytes, stream, mr), - rmm::device_buffer{}, - 0); - child_values = cudf::make_lists_column(total_rep_count, - std::move(offs_col), - std::move(bytes_child), - 0, - rmm::device_buffer{}, - stream, - mr); - } else if (elem_type_id == cudf::type_id::STRUCT) { - // Repeated message field (ArrayType(StructType)) inside nested message. - // Build virtual parent info for each occurrence so we can recursively decode children. - auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx); - if (gc_indices.empty()) { - child_values = cudf::make_structs_column(total_rep_count, - std::vector>{}, - 0, - rmm::device_buffer{}, - stream, - mr); - } else { - rmm::device_uvector d_virtual_row_offsets(total_rep_count, stream, mr); - rmm::device_uvector d_virtual_parent_locs(total_rep_count, stream, mr); - auto const rep_blk = (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; - compute_virtual_parents_for_nested_repeated_kernel<<>>( - d_rep_occs.data(), - list_offsets, - d_parent_locs.data(), - d_virtual_row_offsets.data(), - d_virtual_parent_locs.data(), - total_rep_count); - - child_values = build_nested_struct_column(message_data, - d_virtual_row_offsets.data(), - base_offset, - d_virtual_parent_locs, - gc_indices, - schema, - num_fields, - schema_output_types, - default_ints, - default_floats, - default_bools, - default_strings, - enum_valid_values, - enum_names, - d_row_has_invalid_enum, - d_error, - total_rep_count, - stream, - mr, - depth + 1); - } - } else { - child_values = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); - } - - auto list_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - list_offs.release(), - 
rmm::device_buffer{}, - 0); - struct_children.push_back(cudf::make_lists_column(num_rows, - std::move(list_offs_col), - std::move(child_values), - 0, - rmm::device_buffer{}, - stream, - mr)); - } + struct_children.push_back(build_repeated_child_list_column( + message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, + child_schema_idx, schema, num_fields, schema_output_types, + default_ints, default_floats, default_bools, default_strings, + enum_valid_values, enum_names, d_row_has_invalid_enum, d_error, + stream, mr, depth)); continue; } @@ -4676,6 +4022,371 @@ std::unique_ptr build_nested_struct_column( num_rows, std::move(struct_children), struct_null_count, std::move(struct_mask), stream, mr); } +/** + * Build a LIST column for a repeated child field inside a parent message. + * Shared between build_nested_struct_column and build_repeated_struct_column. + */ +std::unique_ptr build_repeated_child_list_column( + uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_parent_rows, + int child_schema_idx, + std::vector const& schema, + int num_fields, + std::vector const& schema_output_types, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int depth) +{ + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((num_parent_rows + threads - 1) / threads); + + auto elem_type_id = schema[child_schema_idx].output_type; + rmm::device_uvector d_rep_info(num_parent_rows, stream, mr); + + std::vector rep_indices = {0}; + rmm::device_uvector d_rep_indices(1, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_indices.data(), + 
rep_indices.data(), + sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + + device_nested_field_descriptor rep_desc; + rep_desc.field_number = schema[child_schema_idx].field_number; + rep_desc.wire_type = schema[child_schema_idx].wire_type; + rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); + rep_desc.is_repeated = true; + rep_desc.parent_idx = -1; + rep_desc.depth = 0; + rep_desc.encoding = 0; + rep_desc.is_required = false; + rep_desc.has_default_value = false; + + std::vector h_rep_schema = {rep_desc}; + rmm::device_uvector d_rep_schema(1, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_schema.data(), + h_rep_schema.data(), + sizeof(device_nested_field_descriptor), + cudaMemcpyHostToDevice, + stream.value())); + + count_repeated_in_nested_kernel<<>>(message_data, + row_offsets, + base_offset, + parent_locs, + num_parent_rows, + d_rep_schema.data(), + 1, + d_rep_info.data(), + 1, + d_rep_indices.data(), + d_error.data()); + + rmm::device_uvector d_rep_counts(num_parent_rows, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_info.data(), + d_rep_info.end(), + d_rep_counts.data(), + [] __device__(repeated_field_info const& info) { return info.count; }); + int total_rep_count = + thrust::reduce(rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), 0); + + if (total_rep_count == 0) { + rmm::device_uvector list_offsets_vec(num_parent_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy(stream), list_offsets_vec.data(), list_offsets_vec.end(), 0); + auto list_offsets_col = + std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_parent_rows + 1, + list_offsets_vec.release(), + rmm::device_buffer{}, + 0); + std::unique_ptr child_col; + if (elem_type_id == cudf::type_id::STRUCT) { + child_col = make_empty_struct_column_with_schema( + schema, schema_output_types, child_schema_idx, num_fields, stream, mr); + } else { + child_col = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, 
mr); + } + return cudf::make_lists_column(num_parent_rows, + std::move(list_offsets_col), + std::move(child_col), + 0, + rmm::device_buffer{}, + stream, + mr); + } + + rmm::device_uvector list_offs(num_parent_rows + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), list_offs.begin(), 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_parent_rows, + &total_rep_count, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); + scan_repeated_in_nested_kernel<<>>(message_data, + row_offsets, + base_offset, + parent_locs, + num_parent_rows, + d_rep_schema.data(), + 1, + list_offs.data(), + 1, + d_rep_indices.data(), + d_rep_occs.data(), + d_error.data()); + + std::unique_ptr child_values; + if (elem_type_id == cudf::type_id::INT32) { + rmm::device_uvector values(total_rep_count, stream, mr); + NestedRepeatedLocationProvider loc_provider{ + row_offsets, base_offset, parent_locs, d_rep_occs.data()}; + extract_varint_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); + child_values = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); + } else if (elem_type_id == cudf::type_id::INT64) { + rmm::device_uvector values(total_rep_count, stream, mr); + NestedRepeatedLocationProvider loc_provider{ + row_offsets, base_offset, parent_locs, d_rep_occs.data()}; + extract_varint_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); + child_values = std::make_unique(cudf::data_type{cudf::type_id::INT64}, + total_rep_count, + values.release(), + rmm::device_buffer{}, 
+ 0); + } else if (elem_type_id == cudf::type_id::BOOL8) { + rmm::device_uvector values(total_rep_count, stream, mr); + NestedRepeatedLocationProvider loc_provider{ + row_offsets, base_offset, parent_locs, d_rep_occs.data()}; + extract_varint_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); + child_values = std::make_unique(cudf::data_type{cudf::type_id::BOOL8}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); + } else if (elem_type_id == cudf::type_id::FLOAT32) { + rmm::device_uvector values(total_rep_count, stream, mr); + NestedRepeatedLocationProvider loc_provider{ + row_offsets, base_offset, parent_locs, d_rep_occs.data()}; + extract_fixed_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); + child_values = std::make_unique(cudf::data_type{cudf::type_id::FLOAT32}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); + } else if (elem_type_id == cudf::type_id::FLOAT64) { + rmm::device_uvector values(total_rep_count, stream, mr); + NestedRepeatedLocationProvider loc_provider{ + row_offsets, base_offset, parent_locs, d_rep_occs.data()}; + extract_fixed_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>( + message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data()); + child_values = std::make_unique(cudf::data_type{cudf::type_id::FLOAT64}, + total_rep_count, + values.release(), + rmm::device_buffer{}, + 0); + } else if (elem_type_id == cudf::type_id::STRING) { + rmm::device_uvector d_str_lengths(total_rep_count, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_occs.data(), + d_rep_occs.end(), + 
d_str_lengths.data(), + [] __device__(repeated_occurrence const& occ) { return occ.length; }); + + int32_t total_chars = + thrust::reduce(rmm::exec_policy(stream), d_str_lengths.data(), d_str_lengths.end(), 0); + rmm::device_uvector str_offs(total_rep_count + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy(stream), + d_str_lengths.data(), + d_str_lengths.end(), + str_offs.data(), + 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data() + total_rep_count, + &total_chars, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + NestedRepeatedLocationProvider loc_provider{ + row_offsets, base_offset, parent_locs, d_rep_occs.data()}; + copy_varlen_data_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + loc_provider, + total_rep_count, + str_offs.data(), + chars.data(), + d_error.data()); + } + + auto str_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_rep_count + 1, + str_offs.release(), + rmm::device_buffer{}, + 0); + child_values = cudf::make_strings_column( + total_rep_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); + } else if (elem_type_id == cudf::type_id::LIST) { + rmm::device_uvector d_len(total_rep_count, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_occs.data(), + d_rep_occs.end(), + d_len.data(), + [] __device__(repeated_occurrence const& occ) { return occ.length; }); + + int32_t total_bytes = + thrust::reduce(rmm::exec_policy(stream), d_len.data(), d_len.end(), 0); + rmm::device_uvector byte_offs(total_rep_count + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy(stream), d_len.data(), d_len.end(), byte_offs.data(), 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(byte_offs.data() + total_rep_count, + &total_bytes, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector 
bytes(total_bytes, stream, mr); + if (total_bytes > 0) { + NestedRepeatedLocationProvider loc_provider{ + row_offsets, base_offset, parent_locs, d_rep_occs.data()}; + copy_varlen_data_kernel + <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK, + THREADS_PER_BLOCK, + 0, + stream.value()>>>(message_data, + loc_provider, + total_rep_count, + byte_offs.data(), + bytes.data(), + d_error.data()); + } + + auto offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + total_rep_count + 1, + byte_offs.release(), + rmm::device_buffer{}, + 0); + auto bytes_child = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, + total_bytes, + rmm::device_buffer(bytes.data(), total_bytes, stream, mr), + rmm::device_buffer{}, + 0); + child_values = cudf::make_lists_column(total_rep_count, + std::move(offs_col), + std::move(bytes_child), + 0, + rmm::device_buffer{}, + stream, + mr); + } else if (elem_type_id == cudf::type_id::STRUCT) { + auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx); + if (gc_indices.empty()) { + child_values = cudf::make_structs_column(total_rep_count, + std::vector>{}, + 0, + rmm::device_buffer{}, + stream, + mr); + } else { + rmm::device_uvector d_virtual_row_offsets(total_rep_count, stream, mr); + rmm::device_uvector d_virtual_parent_locs(total_rep_count, stream, mr); + auto const rep_blk = (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + compute_virtual_parents_for_nested_repeated_kernel<<>>( + d_rep_occs.data(), + row_offsets, + parent_locs, + d_virtual_row_offsets.data(), + d_virtual_parent_locs.data(), + total_rep_count); + + child_values = build_nested_struct_column(message_data, + d_virtual_row_offsets.data(), + base_offset, + d_virtual_parent_locs, + gc_indices, + schema, + num_fields, + schema_output_types, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error, + total_rep_count, + stream, + 
mr, + depth + 1); + } + } else { + child_values = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); + } + + auto list_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_parent_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + return cudf::make_lists_column(num_parent_rows, + std::move(list_offs_col), + std::move(child_values), + 0, + rmm::device_buffer{}, + stream, + mr); +} + } // anonymous namespace std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, From c009f62ae24739344cc52fbcc86e49545165b351 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sat, 28 Feb 2026 13:29:45 +0800 Subject: [PATCH 032/107] style Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 139 ++++++++++++++++++++--------------- 1 file changed, 79 insertions(+), 60 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 85772cfacc..456a6f6424 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -216,8 +216,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con } return -1; } - case WT_EGROUP: - return 0; + case WT_EGROUP: return 0; default: return -1; } } @@ -2369,7 +2368,7 @@ std::unique_ptr make_null_list_column_with_child( offsets.release(), rmm::device_buffer{}, 0); - auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); + auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); return cudf::make_lists_column(num_rows, std::move(offsets_col), std::move(child_col), @@ -2382,18 +2381,16 @@ std::unique_ptr make_null_list_column_with_child( /** * Wrap a 0-row element column into a 0-row LIST column. 
*/ -std::unique_ptr make_empty_list_column( - std::unique_ptr element_col, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) +std::unique_ptr make_empty_list_column(std::unique_ptr element_col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) { - auto offsets_col = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - 1, - rmm::device_buffer(sizeof(int32_t), stream, mr), - rmm::device_buffer{}, - 0); - int32_t zero = 0; + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 1, + rmm::device_buffer(sizeof(int32_t), stream, mr), + rmm::device_buffer{}, + 0); + int32_t zero = 0; CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), &zero, sizeof(int32_t), @@ -3136,8 +3133,9 @@ std::unique_ptr build_nested_struct_column( rmm::device_async_resource_ref mr, int depth); -// Forward declaration -- build_repeated_child_list_column is defined after build_nested_struct_column -// but both build_repeated_struct_column and build_nested_struct_column need to call it. +// Forward declaration -- build_repeated_child_list_column is defined after +// build_nested_struct_column but both build_repeated_struct_column and build_nested_struct_column +// need to call it. 
std::unique_ptr build_repeated_child_list_column( uint8_t const* message_data, cudf::size_type const* row_offsets, @@ -3327,19 +3325,33 @@ std::unique_ptr build_repeated_struct_column( std::vector> struct_children; int num_schema_fields = static_cast(h_device_schema.size()); for (int ci = 0; ci < num_child_fields; ci++) { - int child_schema_idx = child_field_indices[ci]; - auto const dt = schema_output_types[child_schema_idx]; - auto const enc = h_device_schema[child_schema_idx].encoding; - bool has_def = h_device_schema[child_schema_idx].has_default_value; + int child_schema_idx = child_field_indices[ci]; + auto const dt = schema_output_types[child_schema_idx]; + auto const enc = h_device_schema[child_schema_idx].encoding; + bool has_def = h_device_schema[child_schema_idx].has_default_value; bool child_is_repeated = h_device_schema[child_schema_idx].is_repeated; if (child_is_repeated) { - struct_children.push_back(build_repeated_child_list_column( - message_data, d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), total_count, - child_schema_idx, schema, num_schema_fields, schema_output_types, - default_ints, default_floats, default_bools, default_strings, - enum_valid_values, enum_names, d_row_has_invalid_enum, d_error_top, - stream, mr, 1)); + struct_children.push_back(build_repeated_child_list_column(message_data, + d_msg_row_offsets_size.data(), + 0, + d_msg_locs.data(), + total_count, + child_schema_idx, + schema, + num_schema_fields, + schema_output_types, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error_top, + stream, + mr, + 1)); continue; } @@ -3682,12 +3694,26 @@ std::unique_ptr build_nested_struct_column( bool is_repeated = schema[child_schema_idx].is_repeated; if (is_repeated) { - struct_children.push_back(build_repeated_child_list_column( - message_data, list_offsets, base_offset, d_parent_locs.data(), num_rows, - child_schema_idx, schema, num_fields, 
schema_output_types, - default_ints, default_floats, default_bools, default_strings, - enum_valid_values, enum_names, d_row_has_invalid_enum, d_error, - stream, mr, depth)); + struct_children.push_back(build_repeated_child_list_column(message_data, + list_offsets, + base_offset, + d_parent_locs.data(), + num_rows, + child_schema_idx, + schema, + num_fields, + schema_output_types, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error, + stream, + mr, + depth)); continue; } @@ -4056,11 +4082,8 @@ std::unique_ptr build_repeated_child_list_column( std::vector rep_indices = {0}; rmm::device_uvector d_rep_indices(1, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_indices.data(), - rep_indices.data(), - sizeof(int), - cudaMemcpyHostToDevice, - stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync( + d_rep_indices.data(), rep_indices.data(), sizeof(int), cudaMemcpyHostToDevice, stream.value())); device_nested_field_descriptor rep_desc; rep_desc.field_number = schema[child_schema_idx].field_number; @@ -4105,12 +4128,11 @@ std::unique_ptr build_repeated_child_list_column( if (total_rep_count == 0) { rmm::device_uvector list_offsets_vec(num_parent_rows + 1, stream, mr); thrust::fill(rmm::exec_policy(stream), list_offsets_vec.data(), list_offsets_vec.end(), 0); - auto list_offsets_col = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_parent_rows + 1, - list_offsets_vec.release(), - rmm::device_buffer{}, - 0); + auto list_offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_parent_rows + 1, + list_offsets_vec.release(), + rmm::device_buffer{}, + 0); std::unique_ptr child_col; if (elem_type_id == cudf::type_id::STRUCT) { child_col = make_empty_struct_column_with_schema( @@ -4237,11 +4259,8 @@ std::unique_ptr build_repeated_child_list_column( int32_t total_chars = thrust::reduce(rmm::exec_policy(stream), d_str_lengths.data(), d_str_lengths.end(), 0); 
rmm::device_uvector str_offs(total_rep_count + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - d_str_lengths.data(), - d_str_lengths.end(), - str_offs.data(), - 0); + thrust::exclusive_scan( + rmm::exec_policy(stream), d_str_lengths.data(), d_str_lengths.end(), str_offs.data(), 0); CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data() + total_rep_count, &total_chars, sizeof(int32_t), @@ -4279,8 +4298,7 @@ std::unique_ptr build_repeated_child_list_column( d_len.data(), [] __device__(repeated_occurrence const& occ) { return occ.length; }); - int32_t total_bytes = - thrust::reduce(rmm::exec_policy(stream), d_len.data(), d_len.end(), 0); + int32_t total_bytes = thrust::reduce(rmm::exec_policy(stream), d_len.data(), d_len.end(), 0); rmm::device_uvector byte_offs(total_rep_count + 1, stream, mr); thrust::exclusive_scan( rmm::exec_policy(stream), d_len.data(), d_len.end(), byte_offs.data(), 0); @@ -4306,17 +4324,17 @@ std::unique_ptr build_repeated_child_list_column( d_error.data()); } - auto offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + auto offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, total_rep_count + 1, byte_offs.release(), rmm::device_buffer{}, 0); - auto bytes_child = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, - total_bytes, - rmm::device_buffer(bytes.data(), total_bytes, stream, mr), - rmm::device_buffer{}, - 0); + auto bytes_child = + std::make_unique(cudf::data_type{cudf::type_id::UINT8}, + total_bytes, + rmm::device_buffer(bytes.data(), total_bytes, stream, mr), + rmm::device_buffer{}, + 0); child_values = cudf::make_lists_column(total_rep_count, std::move(offs_col), std::move(bytes_child), @@ -5370,10 +5388,11 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } else { empty_child = make_empty_column_safe(element_type, stream, mr); } - top_level_children.push_back(make_null_list_column_with_child( - std::move(empty_child), num_rows, stream, mr)); + 
top_level_children.push_back( + make_null_list_column_with_child(std::move(empty_child), num_rows, stream, mr)); } else { - top_level_children.push_back(make_null_column(schema_output_types[i], num_rows, stream, mr)); + top_level_children.push_back( + make_null_column(schema_output_types[i], num_rows, stream, mr)); } } } From 6152733d6137b8c16dee3393d11929bc5a62bc5f Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sat, 28 Feb 2026 16:28:07 +0800 Subject: [PATCH 033/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 975 ++++++++---------- .../com/nvidia/spark/rapids/jni/Protobuf.java | 109 +- .../rapids/jni/ProtobufSchemaDescriptor.java | 117 +++ .../nvidia/spark/rapids/jni/ProtobufTest.java | 308 +++++- 4 files changed, 872 insertions(+), 637 deletions(-) create mode 100644 src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 456a6f6424..a254816f32 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -73,6 +73,10 @@ constexpr int ERR_REQUIRED = 9; // Maximum supported nesting depth for recursive struct decoding. constexpr int MAX_NESTED_STRUCT_DECODE_DEPTH = 10; +// Threshold for using a direct-mapped lookup table for field_number -> field_index. +// Field numbers above this threshold fall back to linear search. +constexpr int FIELD_LOOKUP_TABLE_MAX = 4096; + /** * Structure to record field location within a message. * offset < 0 means field was not found. @@ -332,6 +336,47 @@ __device__ inline uint64_t load_le(uint8_t const* p) return v; } +// ============================================================================ +// Field number lookup table helpers +// ============================================================================ + +/** + * Build a host-side direct-mapped lookup table: field_number -> field_index. + * Returns an empty vector if the max field number exceeds the threshold. 
+ */ +inline std::vector build_field_lookup_table(field_descriptor const* descs, int num_fields) +{ + int max_fn = 0; + for (int i = 0; i < num_fields; i++) { + max_fn = std::max(max_fn, descs[i].field_number); + } + if (max_fn > FIELD_LOOKUP_TABLE_MAX) return {}; + std::vector table(max_fn + 1, -1); + for (int i = 0; i < num_fields; i++) { + table[descs[i].field_number] = i; + } + return table; +} + +/** + * O(1) lookup of field_number -> field_index using a direct-mapped table. + * Falls back to linear search when the table is empty (field numbers too large). + */ +__device__ inline int lookup_field(int field_number, + int const* lookup_table, + int lookup_table_size, + field_descriptor const* field_descs, + int num_fields) +{ + if (lookup_table != nullptr && field_number > 0 && field_number < lookup_table_size) { + return lookup_table[field_number]; + } + for (int f = 0; f < num_fields; f++) { + if (field_descs[f].field_number == field_number) return f; + } + return -1; +} + // ============================================================================ // Pass 1: Scan all fields kernel - records (offset, length) for each field // ============================================================================ @@ -342,13 +387,13 @@ __device__ inline uint64_t load_le(uint8_t const* p) * * For "last one wins" semantics (protobuf standard for repeated scalars), * we continue scanning even after finding a field. - * - * @note Time complexity: O(message_length * num_fields) per row. 
*/ __global__ void scan_all_fields_kernel( cudf::column_device_view const d_in, field_descriptor const* field_descs, // [num_fields] int num_fields, + int const* field_lookup, // direct-mapped lookup table (nullable) + int field_lookup_size, // size of lookup table (0 if null) field_location* locations, // [num_rows * num_fields] row-major int* error_flag) { @@ -356,14 +401,11 @@ __global__ void scan_all_fields_kernel( cudf::detail::lists_column_device_view in{d_in}; if (row >= in.size()) return; - // Initialize all field locations to "not found" for (int f = 0; f < num_fields; f++) { locations[row * num_fields + f] = {-1, 0}; } - if (in.nullable() && in.is_null(row)) { - return; // Null input row - all fields remain "not found" - } + if (in.nullable() && in.is_null(row)) { return; } auto const base = in.offset_at(0); auto const child = in.get_sliced_child(); @@ -376,52 +418,45 @@ __global__ void scan_all_fields_kernel( uint8_t const* cur = bytes + start; uint8_t const* msg_end = bytes + end; - // Scan the message once, recording locations of all target fields while (cur < msg_end) { proto_tag tag; if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } int fn = tag.field_number; int wt = tag.wire_type; - // Check if this field is one we're looking for - for (int f = 0; f < num_fields; f++) { - if (field_descs[f].field_number == fn) { - // Check wire type matches - if (wt != field_descs[f].expected_wire_type) { - set_error_once(error_flag, ERR_WIRE_TYPE); - return; - } + int f = lookup_field(fn, field_lookup, field_lookup_size, field_descs, num_fields); + if (f >= 0) { + if (wt != field_descs[f].expected_wire_type) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return; + } - // Record the location (relative to message start) - int data_offset = static_cast(cur - bytes - start); + // Record the location (relative to message start) + int data_offset = static_cast(cur - bytes - start); - if (wt == WT_LEN) { - // For length-delimited, record offset after length 
prefix and the data length - uint64_t len; - int len_bytes; - if (!read_varint(cur, msg_end, len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - if (len > static_cast(msg_end - cur - len_bytes) || - len > static_cast(INT_MAX)) { - set_error_once(error_flag, ERR_OVERFLOW); - return; - } - // Record offset pointing to the actual data (after length prefix) - locations[row * num_fields + f] = {data_offset + len_bytes, static_cast(len)}; - } else { - // For fixed-size and varint fields, record offset and compute length - int field_size = get_wire_type_size(wt, cur, msg_end); - if (field_size < 0) { - set_error_once(error_flag, ERR_FIELD_SIZE); - return; - } - locations[row * num_fields + f] = {data_offset, field_size}; + if (wt == WT_LEN) { + // For length-delimited, record offset after length prefix and the data length + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; } - // "Last one wins" is preserved across later message tags, no need to keep scanning - // descriptors for the same tag once matched. 
- break; + if (len > static_cast(msg_end - cur - len_bytes) || + len > static_cast(INT_MAX)) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + // Record offset pointing to the actual data (after length prefix) + locations[row * num_fields + f] = {data_offset + len_bytes, static_cast(len)}; + } else { + // For fixed-size and varint fields, record offset and compute length + int field_size = get_wire_type_size(wt, cur, msg_end); + if (field_size < 0) { + set_error_once(error_flag, ERR_FIELD_SIZE); + return; + } + locations[row * num_fields + f] = {data_offset, field_size}; } } @@ -538,9 +573,17 @@ __global__ void count_repeated_fields_kernel( count++; } } else if (expected_wt == WT_32BIT) { - count = static_cast(packed_len) / 4; + if ((packed_len % 4) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } + count = static_cast(packed_len / 4); } else if (expected_wt == WT_64BIT) { - count = static_cast(packed_len) / 8; + if ((packed_len % 8) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } + count = static_cast(packed_len / 8); } repeated_info[row * num_repeated_fields + i].count += count; @@ -2094,6 +2137,21 @@ __global__ void compute_msg_locations_from_occurrences_kernel( msg_locs[idx] = {occ.offset, occ.length}; } +/** + * Extract a single field's locations from a 2D strided array on the GPU. + * Replaces a D2H + CPU loop + H2D pattern for nested message location extraction. + */ +__global__ void extract_strided_locations_kernel(field_location const* nested_locations, + int field_idx, + int num_fields, + field_location* parent_locs, + int num_rows) +{ + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= num_rows) return; + parent_locs[row] = nested_locations[row * num_fields + field_idx]; +} + /** * Functor to extract count from repeated_field_info with strided access. * Used for extracting counts for a specific repeated field from 2D array. 
@@ -2812,12 +2870,179 @@ namespace spark_rapids_jni { namespace { +template +std::unique_ptr extract_typed_column( + cudf::data_type dt, + int encoding, + uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_items, + int blocks, + int threads_per_block, + bool has_default, + int64_t default_int, + double default_float, + bool default_bool, + std::vector const& default_string, + int schema_idx, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + switch (dt.id()) { + case cudf::type_id::BOOL8: { + int64_t def_val = has_default ? (default_bool ? 1 : 0) : 0; + return extract_and_build_scalar_column( + dt, + num_items, + [&](uint8_t* out_ptr, bool* valid_ptr) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_val); + }, + stream, + mr); + } + case cudf::type_id::INT32: { + rmm::device_uvector out(num_items, stream, mr); + rmm::device_uvector valid((num_items > 0 ? 
num_items : 1), stream, mr); + extract_integer_into_buffers(message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + has_default, + default_int, + encoding, + true, + out.data(), + valid.data(), + d_error.data(), + stream); + if (schema_idx < static_cast(enum_valid_values.size())) { + auto const& valid_enums = enum_valid_values[schema_idx]; + if (!valid_enums.empty()) { + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + validate_enum_values_kernel<<>>( + out.data(), + valid.data(), + d_row_has_invalid_enum.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + num_items); + } + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + return std::make_unique( + dt, num_items, out.release(), std::move(mask), null_count); + } + case cudf::type_id::UINT32: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + false, + stream, + mr); + case cudf::type_id::INT64: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + true, + stream, + mr); + case cudf::type_id::UINT64: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + false, + stream, + mr); + case cudf::type_id::FLOAT32: { + float def_float_val = has_default ? 
static_cast(default_float) : 0.0f; + return extract_and_build_scalar_column( + dt, + num_items, + [&](float* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_float_val); + }, + stream, + mr); + } + case cudf::type_id::FLOAT64: { + double def_double = has_default ? default_float : 0.0; + return extract_and_build_scalar_column( + dt, + num_items, + [&](double* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_double); + }, + stream, + mr); + } + default: return make_null_column(dt, num_items, stream, mr); + } +} + /** * Helper to build a repeated scalar column (LIST of scalar type). */ template std::unique_ptr build_repeated_scalar_column( cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, device_nested_field_descriptor const& field_desc, rmm::device_uvector const& d_field_counts, rmm::device_uvector& d_occurrences, @@ -2826,10 +3051,6 @@ std::unique_ptr build_repeated_scalar_column( rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - // Get input column's null mask to determine which output rows should be null - // Only rows where INPUT is null should produce null output - // Rows with valid input but count=0 should produce empty array [] - cudf::lists_column_view const in_list(binary_input); auto const input_null_count = binary_input.null_count(); if (total_count == 0) { @@ -2868,26 +3089,16 @@ std::unique_ptr build_repeated_scalar_column( } } - auto const* message_data = reinterpret_cast(in_list.child().data()); - auto const* list_offsets = in_list.offsets().data(); - - cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync( - &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); - 
stream.synchronize(); - rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); - // Set last offset = total_count CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_count, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); - // Extract values rmm::device_uvector values(total_count, stream, mr); rmm::device_uvector d_error(1, stream, mr); CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); @@ -2959,18 +3170,18 @@ std::unique_ptr build_repeated_scalar_column( */ std::unique_ptr build_repeated_string_column( cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, device_nested_field_descriptor const& field_desc, rmm::device_uvector const& d_field_counts, rmm::device_uvector& d_occurrences, int total_count, int num_rows, - bool is_bytes, // true for bytes (LIST), false for string + bool is_bytes, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - // Get input column's null mask to determine which output rows should be null - // Only rows where INPUT is null should produce null output - // Rows with valid input but count=0 should produce empty array [] auto const input_null_count = binary_input.null_count(); if (total_count == 0) { @@ -3008,20 +3219,10 @@ std::unique_ptr build_repeated_string_column( } } - cudf::lists_column_view const in_list(binary_input); - auto const* message_data = reinterpret_cast(in_list.child().data()); - auto const* list_offsets = in_list.offsets().data(); - - cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync( - &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( rmm::exec_policy(stream), d_field_counts.begin(), 
d_field_counts.end(), list_offs.begin(), 0); - // Set last offset = total_count CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_count, sizeof(int32_t), @@ -3165,14 +3366,16 @@ std::unique_ptr build_repeated_child_list_column( */ std::unique_ptr build_repeated_struct_column( cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, device_nested_field_descriptor const& field_desc, rmm::device_uvector const& d_field_counts, rmm::device_uvector& d_occurrences, int total_count, int num_rows, - // Child field information std::vector const& h_device_schema, - std::vector const& child_field_indices, // Indices of child fields in schema + std::vector const& child_field_indices, std::vector const& schema_output_types, std::vector const& default_ints, std::vector const& default_floats, @@ -3239,20 +3442,10 @@ std::unique_ptr build_repeated_struct_column( } } - cudf::lists_column_view const in_list(binary_input); - auto const* message_data = reinterpret_cast(in_list.child().data()); - auto const* list_offsets = in_list.offsets().data(); - - cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync( - &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); - // Set last offset = total_count (already computed on caller side) CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &total_count, sizeof(int32_t), @@ -3356,128 +3549,39 @@ std::unique_ptr build_repeated_struct_column( } switch (dt.id()) { - case cudf::type_id::BOOL8: { - int64_t def_val = has_def ? (default_bools[child_schema_idx] ? 
1 : 0) : 0; - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields}; - struct_children.push_back(extract_and_build_scalar_column( - dt, - total_count, - [&](uint8_t* out_ptr, bool* valid_ptr) { - extract_varint_kernel - <<>>(message_data, - loc_provider, - total_count, - out_ptr, - valid_ptr, - d_error.data(), - has_def, - def_val); - }, - stream, - mr)); - break; - } - case cudf::type_id::INT32: { - int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields}; - struct_children.push_back(extract_and_build_integer_column(dt, - message_data, - loc_provider, - total_count, - blocks, - threads, - d_error, - has_def, - def_int, - enc, - true, - stream, - mr)); - break; - } - case cudf::type_id::INT64: { - int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields}; - struct_children.push_back(extract_and_build_integer_column(dt, - message_data, - loc_provider, - total_count, - blocks, - threads, - d_error, - has_def, - def_int, - enc, - true, - stream, - mr)); - break; - } - case cudf::type_id::FLOAT32: { - float def_float = has_def ? 
static_cast(default_floats[child_schema_idx]) : 0.0f; - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields}; - struct_children.push_back(extract_and_build_scalar_column( - dt, - total_count, - [&](float* out_ptr, bool* valid_ptr) { - extract_fixed_kernel - <<>>(message_data, - loc_provider, - total_count, - out_ptr, - valid_ptr, - d_error.data(), - has_def, - def_float); - }, - stream, - mr)); - break; - } + case cudf::type_id::BOOL8: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT32: case cudf::type_id::FLOAT64: { - double def_double = has_def ? default_floats[child_schema_idx] : 0.0; RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), 0, d_msg_locs.data(), d_child_locs.data(), ci, num_child_fields}; - struct_children.push_back(extract_and_build_scalar_column( - dt, - total_count, - [&](double* out_ptr, bool* valid_ptr) { - extract_fixed_kernel - <<>>(message_data, - loc_provider, - total_count, - out_ptr, - valid_ptr, - d_error.data(), - has_def, - def_double); - }, - stream, - mr)); + struct_children.push_back( + extract_typed_column(dt, + enc, + message_data, + loc_provider, + total_count, + blocks, + threads, + has_def, + has_def ? default_ints[child_schema_idx] : 0, + has_def ? default_floats[child_schema_idx] : 0.0, + has_def ? 
default_bools[child_schema_idx] : false, + default_strings[child_schema_idx], + child_schema_idx, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error, + stream, + mr)); break; } case cudf::type_id::STRING: { @@ -3631,7 +3735,7 @@ std::unique_ptr build_nested_struct_column( rmm::device_async_resource_ref mr, int depth) { - CUDF_EXPECTS(depth <= MAX_NESTED_STRUCT_DECODE_DEPTH, + CUDF_EXPECTS(depth < MAX_NESTED_STRUCT_DECODE_DEPTH, "Nested protobuf struct depth exceeds supported decode recursion limit"); if (num_rows == 0) { @@ -3718,174 +3822,39 @@ std::unique_ptr build_nested_struct_column( } switch (dt.id()) { - case cudf::type_id::BOOL8: { - int64_t def_val = has_def ? (default_bools[child_schema_idx] ? 1 : 0) : 0; - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - struct_children.push_back(extract_and_build_scalar_column( - dt, - num_rows, - [&](uint8_t* out_ptr, bool* valid_ptr) { - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - d_error.data(), - has_def, - def_val); - }, - stream, - mr)); - break; - } - case cudf::type_id::INT32: { - int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - struct_children.push_back(extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_rows, - blocks, - threads, - d_error, - has_def, - def_int, - enc, - true, - stream, - mr)); - break; - } - case cudf::type_id::UINT32: { - int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - struct_children.push_back(extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_rows, - blocks, - threads, - d_error, - has_def, - def_int, - enc, - false, - stream, - mr)); - break; - } - case cudf::type_id::INT64: { - int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - struct_children.push_back(extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_rows, - blocks, - threads, - d_error, - has_def, - def_int, - enc, - true, - stream, - mr)); - break; - } - case cudf::type_id::UINT64: { - int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - struct_children.push_back(extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_rows, - blocks, - threads, - d_error, - has_def, - def_int, - enc, - false, - stream, - mr)); - break; - } - case cudf::type_id::FLOAT32: { - float def_float = has_def ? 
static_cast(default_floats[child_schema_idx]) : 0.0f; - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - struct_children.push_back(extract_and_build_scalar_column( - dt, - num_rows, - [&](float* out_ptr, bool* valid_ptr) { - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - d_error.data(), - has_def, - def_float); - }, - stream, - mr)); - break; - } + case cudf::type_id::BOOL8: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT32: case cudf::type_id::FLOAT64: { - double def_double = has_def ? default_floats[child_schema_idx] : 0.0; NestedLocationProvider loc_provider{list_offsets, base_offset, d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields}; - struct_children.push_back(extract_and_build_scalar_column( - dt, - num_rows, - [&](double* out_ptr, bool* valid_ptr) { - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - d_error.data(), - has_def, - def_double); - }, - stream, - mr)); + struct_children.push_back( + extract_typed_column(dt, + enc, + message_data, + loc_provider, + num_rows, + blocks, + threads, + has_def, + has_def ? default_ints[child_schema_idx] : 0, + has_def ? default_floats[child_schema_idx] : 0.0, + has_def ? 
default_bools[child_schema_idx] : false, + default_strings[child_schema_idx], + child_schema_idx, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error, + stream, + mr)); break; } case cudf::type_id::STRING: { @@ -4591,8 +4560,24 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& rmm::device_uvector d_locations( static_cast(num_rows) * num_scalar, stream, mr); + auto h_field_lookup = build_field_lookup_table(h_field_descs.data(), num_scalar); + rmm::device_uvector d_field_lookup(h_field_lookup.size(), stream, mr); + if (!h_field_lookup.empty()) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_lookup.data(), + h_field_lookup.data(), + h_field_lookup.size() * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + } + scan_all_fields_kernel<<>>( - *d_in, d_field_descs.data(), num_scalar, d_locations.data(), d_error.data()); + *d_in, + d_field_descs.data(), + num_scalar, + h_field_lookup.empty() ? nullptr : d_field_lookup.data(), + static_cast(h_field_lookup.size()), + d_locations.data(), + d_error.data()); check_error_and_throw(); // Check required fields (after scan pass) @@ -4640,170 +4625,34 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& bool has_def = schema[schema_idx].has_default_value; switch (dt.id()) { - case cudf::type_id::BOOL8: { - int64_t def_val = has_def ? (default_bools[schema_idx] ? 1 : 0) : 0; - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - column_map[schema_idx] = extract_and_build_scalar_column( - dt, - num_rows, - [&](uint8_t* out_ptr, bool* valid_ptr) { - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - d_error.data(), - has_def, - def_val); - }, - stream, - mr); - break; - } - case cudf::type_id::INT32: { - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); - int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_integer_into_buffers(message_data, - loc_provider, - num_rows, - blocks, - threads, - has_def, - def_int, - enc, - true, - out.data(), - valid.data(), - d_error.data(), - stream); - // Enum validation: check if this INT32 field has valid enum values - if (schema_idx < static_cast(enum_valid_values.size())) { - auto const& valid_enums = enum_valid_values[schema_idx]; - if (!valid_enums.empty()) { - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - validate_enum_values_kernel<<>>( - out.data(), - valid.data(), - d_row_has_invalid_enum.data(), - d_valid_enums.data(), - static_cast(valid_enums.size()), - num_rows); - } - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - column_map[schema_idx] = std::make_unique( - dt, num_rows, out.release(), std::move(mask), null_count); - break; - } - case cudf::type_id::UINT32: { - int64_t def_int = has_def ? default_ints[schema_idx] : 0; - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - column_map[schema_idx] = extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_rows, - blocks, - threads, - d_error, - has_def, - def_int, - enc, - false, - stream, - mr); - break; - } - case cudf::type_id::INT64: { - int64_t def_int = has_def ? 
default_ints[schema_idx] : 0; - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - column_map[schema_idx] = extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_rows, - blocks, - threads, - d_error, - has_def, - def_int, - enc, - true, - stream, - mr); - break; - } - case cudf::type_id::UINT64: { - int64_t def_int = has_def ? default_ints[schema_idx] : 0; - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - column_map[schema_idx] = extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_rows, - blocks, - threads, - d_error, - has_def, - def_int, - enc, - false, - stream, - mr); - break; - } - case cudf::type_id::FLOAT32: { - float def_float = has_def ? static_cast(default_floats[schema_idx]) : 0.0f; - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - column_map[schema_idx] = extract_and_build_scalar_column( - dt, - num_rows, - [&](float* out_ptr, bool* valid_ptr) { - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - d_error.data(), - has_def, - def_float); - }, - stream, - mr); - break; - } + case cudf::type_id::BOOL8: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT32: case cudf::type_id::FLOAT64: { - double def_double = has_def ? 
default_floats[schema_idx] : 0.0; TopLevelLocationProvider loc_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; - column_map[schema_idx] = extract_and_build_scalar_column( - dt, - num_rows, - [&](double* out_ptr, bool* valid_ptr) { - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - d_error.data(), - has_def, - def_double); - }, - stream, - mr); + column_map[schema_idx] = extract_typed_column(dt, + enc, + message_data, + loc_provider, + num_rows, + blocks, + threads, + has_def, + has_def ? default_ints[schema_idx] : 0, + has_def ? default_floats[schema_idx] : 0.0, + has_def ? default_bools[schema_idx] : false, + default_strings[schema_idx], + schema_idx, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error, + stream, + mr); break; } case cudf::type_id::STRING: { @@ -4910,6 +4759,12 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& if (num_repeated > 0) { cudf::lists_column_view const in_list_view(binary_input); auto const* list_offsets = in_list_view.offsets().data(); + auto const* message_data = + reinterpret_cast(in_list_view.child().data()); + cudf::size_type base_offset = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync( + &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); for (int ri = 0; ri < num_repeated; ri++) { int schema_idx = repeated_field_indices[ri]; @@ -4923,8 +4778,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_field_counts.data(), extract_strided_count{d_repeated_info.data(), ri, num_repeated}); - int total_count = - thrust::reduce(rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), 0); + int64_t total_count = thrust::reduce( + rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), int64_t{0}); if (total_count > 0) { // Build offsets for occurrence scanning on GPU (performance fix!) 
@@ -4962,6 +4817,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case cudf::type_id::INT32: column_map[schema_idx] = build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, h_device_schema[schema_idx], d_field_counts, d_occurrences, @@ -4973,6 +4831,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case cudf::type_id::INT64: column_map[schema_idx] = build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, h_device_schema[schema_idx], d_field_counts, d_occurrences, @@ -4984,6 +4845,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case cudf::type_id::UINT32: column_map[schema_idx] = build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, h_device_schema[schema_idx], d_field_counts, d_occurrences, @@ -4995,6 +4859,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case cudf::type_id::UINT64: column_map[schema_idx] = build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, h_device_schema[schema_idx], d_field_counts, d_occurrences, @@ -5006,6 +4873,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case cudf::type_id::FLOAT32: column_map[schema_idx] = build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, h_device_schema[schema_idx], d_field_counts, d_occurrences, @@ -5017,6 +4887,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case cudf::type_id::FLOAT64: column_map[schema_idx] = build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, h_device_schema[schema_idx], d_field_counts, d_occurrences, @@ -5028,6 +4901,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case cudf::type_id::BOOL8: column_map[schema_idx] = build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, 
h_device_schema[schema_idx], d_field_counts, d_occurrences, @@ -5047,23 +4923,14 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& auto const& valid_enums = enum_valid_values[schema_idx]; auto const& name_bytes = enum_names[schema_idx]; - cudf::lists_column_view const in_lv(binary_input); - auto const* msg_data = reinterpret_cast(in_lv.child().data()); - auto const* loffs = in_lv.offsets().data(); - - cudf::size_type boff = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync( - &boff, loffs, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - // 1. Extract enum integer values from occurrences rmm::device_uvector enum_ints(total_count, stream, mr); auto const rep_blocks = static_cast((total_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK); extract_repeated_varint_kernel - <<>>(msg_data, - loffs, - boff, + <<>>(message_data, + list_offsets, + base_offset, d_occurrences.data(), total_count, enum_ints.data(), @@ -5199,6 +5066,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } } else { column_map[schema_idx] = build_repeated_string_column(binary_input, + message_data, + list_offsets, + base_offset, h_device_schema[schema_idx], d_field_counts, d_occurrences, @@ -5212,6 +5082,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } case cudf::type_id::LIST: // bytes as LIST column_map[schema_idx] = build_repeated_string_column(binary_input, + message_data, + list_offsets, + base_offset, h_device_schema[schema_idx], d_field_counts, d_occurrences, @@ -5231,6 +5104,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& std::move(empty_struct_child), num_rows, stream, mr); } else { column_map[schema_idx] = build_repeated_struct_column(binary_input, + message_data, + list_offsets, + base_offset, h_device_schema[schema_idx], d_field_counts, d_occurrences, @@ -5305,15 +5181,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Process nested struct fields 
(Phase 2) if (num_nested > 0) { - // Copy nested locations to host for processing - std::vector h_nested_locations(static_cast(num_rows) * num_nested); - CUDF_CUDA_TRY(cudaMemcpyAsync(h_nested_locations.data(), - d_nested_locations.data(), - h_nested_locations.size() * sizeof(field_location), - cudaMemcpyDeviceToHost, - stream.value())); - stream.synchronize(); - cudf::lists_column_view const in_list_view(binary_input); auto const* message_data = reinterpret_cast(in_list_view.child().data()); @@ -5337,17 +5204,11 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& continue; } - // Prepare parent locations for this nested field + // Extract parent locations for this nested field directly on GPU rmm::device_uvector d_parent_locs(num_rows, stream, mr); - std::vector h_parent_locs(num_rows); - for (int row = 0; row < num_rows; row++) { - h_parent_locs[row] = h_nested_locations[row * num_nested + ni]; - } - CUDF_CUDA_TRY(cudaMemcpyAsync(d_parent_locs.data(), - h_parent_locs.data(), - num_rows * sizeof(field_location), - cudaMemcpyHostToDevice, - stream.value())); + extract_strided_locations_kernel<<>>( + d_nested_locations.data(), ni, num_nested, d_parent_locs.data(), num_rows); + column_map[parent_schema_idx] = build_nested_struct_column(message_data, list_offsets, base_offset, diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index df807a0ba8..bd8f9632d0 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -63,34 +63,31 @@ public class Protobuf { private static final int MAX_FIELD_NUMBER = (1 << 29) - 1; /** - * Decode protobuf messages into a STRUCT column using a flattened schema representation. - * - * The schema is represented as parallel arrays where nested fields have parent indices - * pointing to their containing message field. 
For pure scalar schemas, all fields are - * top-level (parentIndices == -1, depthLevels == 0, isRepeated == false). + * Decode protobuf messages into a STRUCT column. * * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. - * @param fieldNumbers Protobuf field numbers for all fields in the flattened schema. - * @param parentIndices Parent field index for each field (-1 for top-level fields). - * @param depthLevels Nesting depth for each field (0 for top-level). - * @param wireTypes Expected wire type for each field (WT_VARINT, WT_64BIT, WT_LEN, WT_32BIT). - * @param outputTypeIds cudf native type ids for output columns. - * @param encodings Encoding info for each field (0=default, 1=fixed, 2=zigzag, - * 3=enum-as-string). - * @param isRepeated Whether each field is a repeated field (array). - * @param isRequired Whether each field is required (proto2). - * @param hasDefaultValue Whether each field has a default value. - * @param defaultInts Default values for int/long/enum fields. - * @param defaultFloats Default values for float/double fields. - * @param defaultBools Default values for bool fields. - * @param defaultStrings Default values for string/bytes fields as UTF-8 bytes. - * @param enumValidValues Valid enum values for each field (null if not an enum). - * @param enumNames Enum value names for enum-as-string fields (null if not enum-as-string). - * For each field, this is a byte[][] containing UTF-8 enum names ordered by - * the same sorted order as enumValidValues for that field. + * @param schema descriptor containing flattened schema arrays (field numbers, types, defaults, etc.) * @param failOnErrors if true, throw an exception on malformed protobuf messages. * @return a cudf STRUCT column with nested structure. 
*/ + public static ColumnVector decodeToStruct(ColumnView binaryInput, + ProtobufSchemaDescriptor schema, + boolean failOnErrors) { + long handle = decodeToStruct(binaryInput.getNativeView(), + schema.fieldNumbers, schema.parentIndices, schema.depthLevels, + schema.wireTypes, schema.outputTypeIds, schema.encodings, + schema.isRepeated, schema.isRequired, schema.hasDefaultValue, + schema.defaultInts, schema.defaultFloats, schema.defaultBools, + schema.defaultStrings, schema.enumValidValues, schema.enumNames, failOnErrors); + return new ColumnVector(handle); + } + + /** + * Decode protobuf messages using individual parallel arrays. + * + * @deprecated Use {@link #decodeToStruct(ColumnView, ProtobufSchemaDescriptor, boolean)} instead. + */ + @Deprecated public static ColumnVector decodeToStruct(ColumnView binaryInput, int[] fieldNumbers, int[] parentIndices, @@ -108,66 +105,20 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, int[][] enumValidValues, byte[][][] enumNames, boolean failOnErrors) { - // Parameter validation - if (fieldNumbers == null || parentIndices == null || depthLevels == null || - wireTypes == null || outputTypeIds == null || encodings == null || - isRepeated == null || isRequired == null || hasDefaultValue == null || - defaultInts == null || defaultFloats == null || defaultBools == null || - defaultStrings == null || enumValidValues == null || enumNames == null) { - throw new IllegalArgumentException("Arrays must be non-null"); - } - - int numFields = fieldNumbers.length; - if (parentIndices.length != numFields || - depthLevels.length != numFields || - wireTypes.length != numFields || - outputTypeIds.length != numFields || - encodings.length != numFields || - isRepeated.length != numFields || - isRequired.length != numFields || - hasDefaultValue.length != numFields || - defaultInts.length != numFields || - defaultFloats.length != numFields || - defaultBools.length != numFields || - defaultStrings.length != numFields || - 
enumValidValues.length != numFields || - enumNames.length != numFields) { - throw new IllegalArgumentException("All arrays must have the same length"); - } - - // Validate field numbers are positive and within protobuf spec range - for (int i = 0; i < fieldNumbers.length; i++) { - if (fieldNumbers[i] <= 0 || fieldNumbers[i] > MAX_FIELD_NUMBER) { - throw new IllegalArgumentException( - "Invalid field number at index " + i + ": " + fieldNumbers[i] + - " (field numbers must be 1-" + MAX_FIELD_NUMBER + ")"); - } - } - - // Validate encoding values - for (int i = 0; i < encodings.length; i++) { - int enc = encodings[i]; - if (enc < ENC_DEFAULT || enc > ENC_ENUM_STRING) { - throw new IllegalArgumentException( - "Invalid encoding value at index " + i + ": " + enc + - " (expected " + ENC_DEFAULT + ", " + ENC_FIXED + ", " + ENC_ZIGZAG + - ", or " + ENC_ENUM_STRING + ")"); - } - } - - long handle = decodeToStruct(binaryInput.getNativeView(), - fieldNumbers, parentIndices, depthLevels, - wireTypes, outputTypeIds, encodings, - isRepeated, isRequired, hasDefaultValue, - defaultInts, defaultFloats, defaultBools, - defaultStrings, enumValidValues, enumNames, failOnErrors); - return new ColumnVector(handle); + return decodeToStruct(binaryInput, + new ProtobufSchemaDescriptor(fieldNumbers, parentIndices, depthLevels, + wireTypes, outputTypeIds, encodings, isRepeated, isRequired, + hasDefaultValue, defaultInts, defaultFloats, defaultBools, + defaultStrings, enumValidValues, enumNames), + failOnErrors); } /** - * Backward-compatible overload for callers that don't provide enum name mappings. - * This keeps existing JNI tests and call-sites source-compatible. + * Backward-compatible overload without enum name mappings. + * + * @deprecated Use {@link #decodeToStruct(ColumnView, ProtobufSchemaDescriptor, boolean)} instead. 
*/ + @Deprecated public static ColumnVector decodeToStruct(ColumnView binaryInput, int[] fieldNumbers, int[] parentIndices, diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java new file mode 100644 index 0000000000..513ead2c37 --- /dev/null +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +/** + * Immutable descriptor for a flattened protobuf schema, grouping the parallel arrays + * that describe field structure, types, defaults, and enum metadata. + * + *

    Use this class instead of passing 15+ individual arrays through the JNI boundary. + * Validation is performed once in the constructor. + */ +public final class ProtobufSchemaDescriptor implements java.io.Serializable { + private static final long serialVersionUID = 1L; + private static final int MAX_FIELD_NUMBER = (1 << 29) - 1; + + public final int[] fieldNumbers; + public final int[] parentIndices; + public final int[] depthLevels; + public final int[] wireTypes; + public final int[] outputTypeIds; + public final int[] encodings; + public final boolean[] isRepeated; + public final boolean[] isRequired; + public final boolean[] hasDefaultValue; + public final long[] defaultInts; + public final double[] defaultFloats; + public final boolean[] defaultBools; + public final byte[][] defaultStrings; + public final int[][] enumValidValues; + public final byte[][][] enumNames; + + /** + * @throws IllegalArgumentException if any array is null, arrays have mismatched lengths, + * field numbers are out of range, or encoding values are invalid. 
+ */ + public ProtobufSchemaDescriptor( + int[] fieldNumbers, + int[] parentIndices, + int[] depthLevels, + int[] wireTypes, + int[] outputTypeIds, + int[] encodings, + boolean[] isRepeated, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + int[][] enumValidValues, + byte[][][] enumNames) { + + if (fieldNumbers == null || parentIndices == null || depthLevels == null || + wireTypes == null || outputTypeIds == null || encodings == null || + isRepeated == null || isRequired == null || hasDefaultValue == null || + defaultInts == null || defaultFloats == null || defaultBools == null || + defaultStrings == null || enumValidValues == null || enumNames == null) { + throw new IllegalArgumentException("All schema arrays must be non-null"); + } + + int n = fieldNumbers.length; + if (parentIndices.length != n || depthLevels.length != n || + wireTypes.length != n || outputTypeIds.length != n || + encodings.length != n || isRepeated.length != n || + isRequired.length != n || hasDefaultValue.length != n || + defaultInts.length != n || defaultFloats.length != n || + defaultBools.length != n || defaultStrings.length != n || + enumValidValues.length != n || enumNames.length != n) { + throw new IllegalArgumentException("All schema arrays must have the same length"); + } + + for (int i = 0; i < n; i++) { + if (fieldNumbers[i] <= 0 || fieldNumbers[i] > MAX_FIELD_NUMBER) { + throw new IllegalArgumentException( + "Invalid field number at index " + i + ": " + fieldNumbers[i] + + " (must be 1-" + MAX_FIELD_NUMBER + ")"); + } + int enc = encodings[i]; + if (enc < Protobuf.ENC_DEFAULT || enc > Protobuf.ENC_ENUM_STRING) { + throw new IllegalArgumentException( + "Invalid encoding at index " + i + ": " + enc); + } + } + + this.fieldNumbers = fieldNumbers; + this.parentIndices = parentIndices; + this.depthLevels = depthLevels; + this.wireTypes = wireTypes; + this.outputTypeIds = 
outputTypeIds; + this.encodings = encodings; + this.isRepeated = isRepeated; + this.isRequired = isRequired; + this.hasDefaultValue = hasDefaultValue; + this.defaultInts = defaultInts; + this.defaultFloats = defaultFloats; + this.defaultBools = defaultBools; + this.defaultStrings = defaultStrings; + this.enumValidValues = enumValidValues; + this.enumNames = enumNames; + } + + public int numFields() { return fieldNumbers.length; } +} diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 3ab1b8c988..6f65e71d55 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -29,6 +29,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertEquals; +import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; @@ -113,7 +114,7 @@ private static byte[] encodeDouble(double d) { /** Create a protobuf tag (field number + wire type). 
*/ private static byte[] tag(int fieldNumber, int wireType) { - return encodeVarint((fieldNumber << 3) | wireType); + return encodeVarint(((long) fieldNumber << 3) | wireType); } // Wire type constants @@ -2537,4 +2538,309 @@ void testRepeatedEnumAsString() { assertEquals(1, actual.getNumChildren()); } } + + // ============================================================================ + // Edge case and boundary tests + // ============================================================================ + + @Test + void testPackedFixedMisaligned() { + byte[] packedData = new byte[]{0x01, 0x02, 0x03, 0x04, 0x05}; + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(packedData.length)), + box(packedData)); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(RuntimeException.class, () -> { + try (ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_32BIT}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_FIXED}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + true)) { + } + }); + } + } + + @Test + void testPackedFixedMisaligned64() { + byte[] packedData = new byte[]{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09}; + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(packedData.length)), + box(packedData)); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(RuntimeException.class, () -> { + try (ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_64BIT}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_FIXED}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new 
double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + true)) { + } + }); + } + } + + @Test + void testLargeRepeatedField() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + for (int i = 0; i < 100000; i++) { + baos.write(tag(1, WT_VARINT)); + baos.write(encodeVarint(i)); + } + Byte[] row = box(baos.toByteArray()); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + try (ColumnVector list = result.getChildColumnView(0).copyToColumnVector()) { + assertEquals(DType.LIST, list.getType()); + } + } + } + + @Test + void testMixedPackedUnpacked() { + byte[] packedContent = concatBytes(encodeVarint(30), encodeVarint(40)); + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(10)), + box(tag(1, WT_VARINT)), box(encodeVarint(20)), + box(tag(1, WT_LEN)), box(encodeVarint(packedContent.length)), box(packedContent)); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + try 
(ColumnVector list = result.getChildColumnView(0).copyToColumnVector(); + ColumnVector vals = list.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostVals = vals.copyToHost()) { + assertEquals(4, vals.getRowCount()); + assertEquals(10, hostVals.getInt(0)); + assertEquals(20, hostVals.getInt(1)); + assertEquals(30, hostVals.getInt(2)); + assertEquals(40, hostVals.getInt(3)); + } + } + } + + @Test + void testLargeFieldNumber() { + int maxFieldNumber = (1 << 29) - 1; + Byte[] row = concat( + box(tag(maxFieldNumber, WT_VARINT)), + box(encodeVarint(42))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{maxFieldNumber}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + try (ColumnVector child = result.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostChild = child.copyToHost()) { + assertEquals(42, hostChild.getInt(0)); + } + } + } + + private void verifyDeepNesting(int numLevels) { + int numFields = 2 * numLevels - 1; + + byte[] current = concatBytes(tag(1, WT_VARINT), encodeVarint(1)); + for (int level = numLevels - 2; level >= 0; level--) { + current = concatBytes( + tag(1, WT_VARINT), encodeVarint(1), + tag(2, WT_LEN), encodeVarint(current.length), current); + } + Byte[] row = box(current); + + int[] fieldNumbers = new int[numFields]; + int[] parentIndices = new int[numFields]; + int[] depthLevels = new int[numFields]; + int[] wireTypes = new int[numFields]; + int[] outputTypeIds = new int[numFields]; + int[] encodings = new int[numFields]; + boolean[] isRepeated = new 
boolean[numFields]; + boolean[] isRequired = new boolean[numFields]; + boolean[] hasDefaultValue = new boolean[numFields]; + long[] defaultInts = new long[numFields]; + double[] defaultFloats = new double[numFields]; + boolean[] defaultBools = new boolean[numFields]; + byte[][] defaultStrings = new byte[numFields][]; + int[][] enumValidValues = new int[numFields][]; + + for (int level = 0; level < numLevels; level++) { + int intIdx = 2 * level; + int parentIdx = level == 0 ? -1 : 2 * (level - 1) + 1; + + fieldNumbers[intIdx] = 1; + parentIndices[intIdx] = parentIdx; + depthLevels[intIdx] = level; + wireTypes[intIdx] = WT_VARINT; + outputTypeIds[intIdx] = DType.INT32.getTypeId().getNativeId(); + encodings[intIdx] = Protobuf.ENC_DEFAULT; + + if (level < numLevels - 1) { + int structIdx = 2 * level + 1; + fieldNumbers[structIdx] = 2; + parentIndices[structIdx] = parentIdx; + depthLevels[structIdx] = level; + wireTypes[structIdx] = WT_LEN; + outputTypeIds[structIdx] = DType.STRUCT.getTypeId().getNativeId(); + encodings[structIdx] = Protobuf.ENC_DEFAULT; + } + } + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + fieldNumbers, parentIndices, depthLevels, wireTypes, + outputTypeIds, encodings, isRepeated, isRequired, + hasDefaultValue, defaultInts, defaultFloats, defaultBools, + defaultStrings, enumValidValues, false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + } + } + + @Test + void testDeepNesting9Levels() { + verifyDeepNesting(9); + } + + @Test + void testDeepNesting10Levels() { + verifyDeepNesting(10); + } + + @Test + void testZeroLengthNestedMessage() { + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(0))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + 
new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + } + } + + @Test + void testEmptyPackedRepeated() { + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(0))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + } + } } From 48204d5a2b27221e73a4fab554abaae4c4193314 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 2 Mar 2026 16:57:37 +0800 Subject: [PATCH 034/107] clean up and split big files Signed-off-by: Haoyang Li --- src/main/cpp/CMakeLists.txt | 2 + src/main/cpp/src/protobuf.cu | 4408 +------------------------ src/main/cpp/src/protobuf_builders.cu | 1407 ++++++++ src/main/cpp/src/protobuf_common.cuh | 1375 ++++++++ src/main/cpp/src/protobuf_kernels.cu | 1106 +++++++ 5 files changed, 3925 insertions(+), 4373 deletions(-) create mode 100644 src/main/cpp/src/protobuf_builders.cu create mode 100644 src/main/cpp/src/protobuf_common.cuh create mode 100644 src/main/cpp/src/protobuf_kernels.cu diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index 3ac34e15bd..9b5ea0af2b 100644 --- 
a/src/main/cpp/CMakeLists.txt +++ b/src/main/cpp/CMakeLists.txt @@ -256,6 +256,8 @@ add_library( src/number_converter.cu src/parse_uri.cu src/protobuf.cu + src/protobuf_kernels.cu + src/protobuf_builders.cu src/regex_rewrite_utils.cu src/row_conversion.cu src/round_float.cu diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index a254816f32..d738ac863f 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -14,4368 +14,12 @@ * limitations under the License. */ -#include "protobuf.hpp" +#include "protobuf_common.cuh" -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace { - -// Wire type constants (protobuf encoding spec) -constexpr int WT_VARINT = 0; -constexpr int WT_64BIT = 1; -constexpr int WT_LEN = 2; -constexpr int WT_SGROUP = 3; -constexpr int WT_EGROUP = 4; -constexpr int WT_32BIT = 5; - -// Protobuf varint encoding uses at most 10 bytes to represent a 64-bit value. -constexpr int MAX_VARINT_BYTES = 10; - -// CUDA kernel launch configuration. -constexpr int THREADS_PER_BLOCK = 256; - -// Error codes for kernel error reporting. -constexpr int ERR_BOUNDS = 1; -constexpr int ERR_VARINT = 2; -constexpr int ERR_FIELD_NUMBER = 3; -constexpr int ERR_WIRE_TYPE = 4; -constexpr int ERR_OVERFLOW = 5; -constexpr int ERR_FIELD_SIZE = 6; -constexpr int ERR_SKIP = 7; -constexpr int ERR_FIXED_LEN = 8; -constexpr int ERR_REQUIRED = 9; - -// Maximum supported nesting depth for recursive struct decoding. -constexpr int MAX_NESTED_STRUCT_DECODE_DEPTH = 10; - -// Threshold for using a direct-mapped lookup table for field_number -> field_index. -// Field numbers above this threshold fall back to linear search. -constexpr int FIELD_LOOKUP_TABLE_MAX = 4096; - -/** - * Structure to record field location within a message. 
- * offset < 0 means field was not found. - */ -struct field_location { - int32_t offset; // Offset of field data within the message (-1 if not found) - int32_t length; // Length of field data in bytes -}; - -/** - * Field descriptor passed to the scanning kernel. - */ -struct field_descriptor { - int field_number; // Protobuf field number - int expected_wire_type; // Expected wire type for this field -}; - -/** - * Information about repeated field occurrences in a row. - */ -struct repeated_field_info { - int32_t count; // Number of occurrences in this row - int32_t total_length; // Total bytes for all occurrences (for varlen fields) -}; - -/** - * Location of a single occurrence of a repeated field. - */ -struct repeated_occurrence { - int32_t row_idx; // Which row this occurrence belongs to - int32_t offset; // Offset within the message - int32_t length; // Length of the field data -}; - -/** - * Device-side descriptor for nested schema fields. - */ -struct device_nested_field_descriptor { - int field_number; - int parent_idx; - int depth; - int wire_type; - int output_type_id; - int encoding; - bool is_repeated; - bool is_required; - bool has_default_value; - - device_nested_field_descriptor() = default; - - explicit device_nested_field_descriptor(spark_rapids_jni::nested_field_descriptor const& src) - : field_number(src.field_number), - parent_idx(src.parent_idx), - depth(src.depth), - wire_type(src.wire_type), - output_type_id(static_cast(src.output_type)), - encoding(src.encoding), - is_repeated(src.is_repeated), - is_required(src.is_required), - has_default_value(src.has_default_value) - { - } -}; - -// ============================================================================ -// Device helper functions -// ============================================================================ - -__device__ inline bool read_varint(uint8_t const* cur, - uint8_t const* end, - uint64_t& out, - int& bytes) -{ - out = 0; - bytes = 0; - int shift = 0; - // Protobuf 
varint uses 7 bits per byte with MSB as continuation flag. - // A 64-bit value requires at most ceil(64/7) = 10 bytes. - while (cur < end && bytes < MAX_VARINT_BYTES) { - uint8_t b = *cur++; - // For the 10th byte (bytes == 9, shift == 63), only the lowest bit is valid - if (bytes == 9 && (b & 0xFE) != 0) { - return false; // Invalid: 10th byte has more than 1 significant bit - } - out |= (static_cast(b & 0x7Fu) << shift); - bytes++; - if ((b & 0x80u) == 0) { return true; } - shift += 7; - } - return false; -} - -__device__ inline void set_error_once(int* error_flag, int error_code) -{ - atomicCAS(error_flag, 0, error_code); -} - -__device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) -{ - switch (wt) { - case WT_VARINT: { - // Need to scan to find the end of varint - int count = 0; - while (cur < end && count < MAX_VARINT_BYTES) { - if ((*cur++ & 0x80u) == 0) { return count + 1; } - count++; - } - return -1; // Invalid varint - } - case WT_64BIT: - // Check if there's enough data for 8 bytes - if (end - cur < 8) return -1; - return 8; - case WT_32BIT: - // Check if there's enough data for 4 bytes - if (end - cur < 4) return -1; - return 4; - case WT_LEN: { - uint64_t len; - int n; - if (!read_varint(cur, end, len, n)) return -1; - if (len > static_cast(end - cur - n) || len > static_cast(INT_MAX - n)) - return -1; - return n + static_cast(len); - } - case WT_SGROUP: { - auto const* start = cur; - // Recursively skip until the matching end-group tag. 
- while (cur < end) { - uint64_t key; - int key_bytes; - if (!read_varint(cur, end, key, key_bytes)) return -1; - cur += key_bytes; - - int inner_wt = static_cast(key & 0x7); - if (inner_wt == WT_EGROUP) { return static_cast(cur - start); } - - int inner_size = get_wire_type_size(inner_wt, cur, end); - if (inner_size < 0 || cur + inner_size > end) return -1; - cur += inner_size; - } - return -1; - } - case WT_EGROUP: return 0; - default: return -1; - } -} - -__device__ inline bool skip_field(uint8_t const* cur, - uint8_t const* end, - int wt, - uint8_t const*& out_cur) -{ - // End-group is handled by the parent group parser. - if (wt == WT_EGROUP) { - out_cur = cur; - return true; - } - - int size = get_wire_type_size(wt, cur, end); - if (size < 0) return false; - // Ensure we don't skip past the end of the buffer - if (cur + size > end) return false; - out_cur = cur + size; - return true; -} - -/** - * Get the data offset and length for a field at current position. - * Returns true on success, false on error. 
- */ -__device__ inline bool get_field_data_location( - uint8_t const* cur, uint8_t const* end, int wt, int32_t& data_offset, int32_t& data_length) -{ - if (wt == WT_LEN) { - // For length-delimited, read the length prefix - uint64_t len; - int len_bytes; - if (!read_varint(cur, end, len, len_bytes)) return false; - if (len > static_cast(end - cur - len_bytes) || - len > static_cast(INT_MAX)) { - return false; - } - data_offset = len_bytes; // offset past the length prefix - data_length = static_cast(len); - } else { - // For fixed-size and varint fields - int field_size = get_wire_type_size(wt, cur, end); - if (field_size < 0) return false; - data_offset = 0; - data_length = field_size; - } - return true; -} - -__device__ inline bool check_message_bounds(int32_t start, - int32_t end_pos, - cudf::size_type total_size, - int* error_flag) -{ - if (start < 0 || end_pos < start || end_pos > total_size) { - set_error_once(error_flag, ERR_BOUNDS); - return false; - } - return true; -} - -struct proto_tag { - int field_number; - int wire_type; -}; - -__device__ inline bool decode_tag(uint8_t const*& cur, - uint8_t const* end, - proto_tag& tag, - int* error_flag) -{ - uint64_t key; - int key_bytes; - if (!read_varint(cur, end, key, key_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return false; - } - - cur += key_bytes; - tag.field_number = static_cast(key >> 3); - tag.wire_type = static_cast(key & 0x7); - if (tag.field_number == 0) { - set_error_once(error_flag, ERR_FIELD_NUMBER); - return false; - } - return true; -} - -/** - * Load a little-endian value from unaligned memory. - * Reads bytes individually to avoid unaligned-access issues on GPU. 
- */ -template -__device__ inline T load_le(uint8_t const* p); - -template <> -__device__ inline uint32_t load_le(uint8_t const* p) -{ - return static_cast(p[0]) | (static_cast(p[1]) << 8) | - (static_cast(p[2]) << 16) | (static_cast(p[3]) << 24); -} - -template <> -__device__ inline uint64_t load_le(uint8_t const* p) -{ - uint64_t v = 0; -#pragma unroll - for (int i = 0; i < 8; ++i) { - v |= (static_cast(p[i]) << (8 * i)); - } - return v; -} - -// ============================================================================ -// Field number lookup table helpers -// ============================================================================ - -/** - * Build a host-side direct-mapped lookup table: field_number -> field_index. - * Returns an empty vector if the max field number exceeds the threshold. - */ -inline std::vector build_field_lookup_table(field_descriptor const* descs, int num_fields) -{ - int max_fn = 0; - for (int i = 0; i < num_fields; i++) { - max_fn = std::max(max_fn, descs[i].field_number); - } - if (max_fn > FIELD_LOOKUP_TABLE_MAX) return {}; - std::vector table(max_fn + 1, -1); - for (int i = 0; i < num_fields; i++) { - table[descs[i].field_number] = i; - } - return table; -} - -/** - * O(1) lookup of field_number -> field_index using a direct-mapped table. - * Falls back to linear search when the table is empty (field numbers too large). 
- */ -__device__ inline int lookup_field(int field_number, - int const* lookup_table, - int lookup_table_size, - field_descriptor const* field_descs, - int num_fields) -{ - if (lookup_table != nullptr && field_number > 0 && field_number < lookup_table_size) { - return lookup_table[field_number]; - } - for (int f = 0; f < num_fields; f++) { - if (field_descs[f].field_number == field_number) return f; - } - return -1; -} - -// ============================================================================ -// Pass 1: Scan all fields kernel - records (offset, length) for each field -// ============================================================================ - -/** - * Fused scanning kernel: scans each message once and records the location - * of all requested fields. - * - * For "last one wins" semantics (protobuf standard for repeated scalars), - * we continue scanning even after finding a field. - */ -__global__ void scan_all_fields_kernel( - cudf::column_device_view const d_in, - field_descriptor const* field_descs, // [num_fields] - int num_fields, - int const* field_lookup, // direct-mapped lookup table (nullable) - int field_lookup_size, // size of lookup table (0 if null) - field_location* locations, // [num_rows * num_fields] row-major - int* error_flag) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - cudf::detail::lists_column_device_view in{d_in}; - if (row >= in.size()) return; - - for (int f = 0; f < num_fields; f++) { - locations[row * num_fields + f] = {-1, 0}; - } - - if (in.nullable() && in.is_null(row)) { return; } - - auto const base = in.offset_at(0); - auto const child = in.get_sliced_child(); - auto const* bytes = reinterpret_cast(child.data()); - int32_t start = in.offset_at(row) - base; - int32_t end = in.offset_at(row + 1) - base; - - if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } - - uint8_t const* cur = bytes + start; - uint8_t const* msg_end = bytes + end; - - while (cur < msg_end) { - 
proto_tag tag; - if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } - int fn = tag.field_number; - int wt = tag.wire_type; - - int f = lookup_field(fn, field_lookup, field_lookup_size, field_descs, num_fields); - if (f >= 0) { - if (wt != field_descs[f].expected_wire_type) { - set_error_once(error_flag, ERR_WIRE_TYPE); - return; - } - - // Record the location (relative to message start) - int data_offset = static_cast(cur - bytes - start); - - if (wt == WT_LEN) { - // For length-delimited, record offset after length prefix and the data length - uint64_t len; - int len_bytes; - if (!read_varint(cur, msg_end, len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - if (len > static_cast(msg_end - cur - len_bytes) || - len > static_cast(INT_MAX)) { - set_error_once(error_flag, ERR_OVERFLOW); - return; - } - // Record offset pointing to the actual data (after length prefix) - locations[row * num_fields + f] = {data_offset + len_bytes, static_cast(len)}; - } else { - // For fixed-size and varint fields, record offset and compute length - int field_size = get_wire_type_size(wt, cur, msg_end); - if (field_size < 0) { - set_error_once(error_flag, ERR_FIELD_SIZE); - return; - } - locations[row * num_fields + f] = {data_offset, field_size}; - } - } - - // Skip to next field - uint8_t const* next; - if (!skip_field(cur, msg_end, wt, next)) { - set_error_once(error_flag, ERR_SKIP); - return; - } - cur = next; - } -} - -// ============================================================================ -// Pass 1b: Count repeated fields kernel -// ============================================================================ - -/** - * Count occurrences of repeated fields in each row. - * Also records locations of nested message fields for hierarchical processing. - * - * @note Time complexity: O(message_length * (num_repeated_fields + num_nested_fields)) per row. 
- */ -__global__ void count_repeated_fields_kernel( - cudf::column_device_view const d_in, - device_nested_field_descriptor const* schema, - int num_fields, - int depth_level, // Which depth level we're processing - repeated_field_info* repeated_info, // [num_rows * num_repeated_fields_at_this_depth] - int num_repeated_fields, // Number of repeated fields at this depth - int const* repeated_field_indices, // Indices into schema for repeated fields at this depth - field_location* - nested_locations, // Locations of nested messages for next depth [num_rows * num_nested] - int num_nested_fields, // Number of nested message fields at this depth - int const* nested_field_indices, // Indices into schema for nested message fields - int* error_flag) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - cudf::detail::lists_column_device_view in{d_in}; - if (row >= in.size()) return; - - // Initialize repeated counts to 0 - for (int f = 0; f < num_repeated_fields; f++) { - repeated_info[row * num_repeated_fields + f] = {0, 0}; - } - - // Initialize nested locations to not found - for (int f = 0; f < num_nested_fields; f++) { - nested_locations[row * num_nested_fields + f] = {-1, 0}; - } - - if (in.nullable() && in.is_null(row)) { return; } - - auto const base = in.offset_at(0); - auto const child = in.get_sliced_child(); - auto const* bytes = reinterpret_cast(child.data()); - int32_t start = in.offset_at(row) - base; - int32_t end = in.offset_at(row + 1) - base; - if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } - - uint8_t const* cur = bytes + start; - uint8_t const* msg_end = bytes + end; - - while (cur < msg_end) { - proto_tag tag; - if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } - int fn = tag.field_number; - int wt = tag.wire_type; - - // Check repeated fields at this depth - for (int i = 0; i < num_repeated_fields; i++) { - int schema_idx = repeated_field_indices[i]; - if (schema[schema_idx].field_number == fn && 
schema[schema_idx].depth == depth_level) { - int expected_wt = schema[schema_idx].wire_type; - - // Handle both packed and unpacked encoding for repeated fields - // Packed encoding uses wire type LEN (2) even for scalar types - bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); - - if (!is_packed && wt != expected_wt) { - set_error_once(error_flag, ERR_WIRE_TYPE); - return; - } - - if (is_packed) { - // Packed encoding: read length, then count elements inside - uint64_t packed_len; - int len_bytes; - if (!read_varint(cur, msg_end, packed_len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - - // Count elements based on type - uint8_t const* packed_start = cur + len_bytes; - uint8_t const* packed_end = packed_start + packed_len; - if (packed_end > msg_end) { - set_error_once(error_flag, ERR_OVERFLOW); - return; - } - - int count = 0; - if (expected_wt == WT_VARINT) { - // Count varints in the packed data - uint8_t const* p = packed_start; - while (p < packed_end) { - uint64_t dummy; - int vbytes; - if (!read_varint(p, packed_end, dummy, vbytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - p += vbytes; - count++; - } - } else if (expected_wt == WT_32BIT) { - if ((packed_len % 4) != 0) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - count = static_cast(packed_len / 4); - } else if (expected_wt == WT_64BIT) { - if ((packed_len % 8) != 0) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - count = static_cast(packed_len / 8); - } - - repeated_info[row * num_repeated_fields + i].count += count; - repeated_info[row * num_repeated_fields + i].total_length += - static_cast(packed_len); - } else { - // Non-packed encoding: single element - int32_t data_offset, data_length; - if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { - set_error_once(error_flag, ERR_FIELD_SIZE); - return; - } - - repeated_info[row * num_repeated_fields + i].count++; - repeated_info[row * 
num_repeated_fields + i].total_length += data_length; - } - } - } - - // Check nested message fields at this depth (last one wins for non-repeated) - for (int i = 0; i < num_nested_fields; i++) { - int schema_idx = nested_field_indices[i]; - if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { - if (wt != WT_LEN) { - set_error_once(error_flag, ERR_WIRE_TYPE); - return; - } - - uint64_t len; - int len_bytes; - if (!read_varint(cur, msg_end, len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - - int32_t msg_offset = static_cast(cur - bytes - start) + len_bytes; - nested_locations[row * num_nested_fields + i] = {msg_offset, static_cast(len)}; - } - } - - // Skip to next field - uint8_t const* next; - if (!skip_field(cur, msg_end, wt, next)) { - set_error_once(error_flag, ERR_SKIP); - return; - } - cur = next; - } -} - -/** - * Scan and record all occurrences of repeated fields. - * Called after count_repeated_fields_kernel to fill in actual locations. - * - * @note Time complexity: O(message_length * num_repeated_fields) per row. 
- */ -__global__ void scan_repeated_field_occurrences_kernel( - cudf::column_device_view const d_in, - device_nested_field_descriptor const* schema, - int schema_idx, // Which field in schema we're scanning - int depth_level, - int32_t const* output_offsets, // Pre-computed offsets from prefix sum [num_rows + 1] - repeated_occurrence* occurrences, // Output: all occurrences [total_count] - int* error_flag) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - cudf::detail::lists_column_device_view in{d_in}; - if (row >= in.size()) return; - - if (in.nullable() && in.is_null(row)) { return; } - - auto const base = in.offset_at(0); - auto const child = in.get_sliced_child(); - auto const* bytes = reinterpret_cast(child.data()); - int32_t start = in.offset_at(row) - base; - int32_t end = in.offset_at(row + 1) - base; - if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } - - uint8_t const* cur = bytes + start; - uint8_t const* msg_end = bytes + end; - - int target_fn = schema[schema_idx].field_number; - int target_wt = schema[schema_idx].wire_type; - int write_idx = output_offsets[row]; - - while (cur < msg_end) { - proto_tag tag; - if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } - int fn = tag.field_number; - int wt = tag.wire_type; - - if (fn == target_fn) { - // Check for packed encoding: wire type LEN but expected non-LEN - bool is_packed = (wt == WT_LEN && target_wt != WT_LEN); - - if (is_packed) { - // Packed encoding: multiple elements in a length-delimited blob - uint64_t packed_len; - int len_bytes; - if (!read_varint(cur, msg_end, packed_len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - - uint8_t const* packed_start = cur + len_bytes; - uint8_t const* packed_end = packed_start + packed_len; - if (packed_end > msg_end) { - set_error_once(error_flag, ERR_OVERFLOW); - return; - } - - // Record each element in the packed blob - if (target_wt == WT_VARINT) { - // Varints: parse each 
one - uint8_t const* p = packed_start; - while (p < packed_end) { - int32_t elem_offset = static_cast(p - bytes - start); - uint64_t dummy; - int vbytes; - if (!read_varint(p, packed_end, dummy, vbytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - occurrences[write_idx] = {static_cast(row), elem_offset, vbytes}; - write_idx++; - p += vbytes; - } - } else if (target_wt == WT_32BIT) { - // Fixed 32-bit: each element is 4 bytes - uint8_t const* p = packed_start; - while (p + 4 <= packed_end) { - int32_t elem_offset = static_cast(p - bytes - start); - occurrences[write_idx] = {static_cast(row), elem_offset, 4}; - write_idx++; - p += 4; - } - } else if (target_wt == WT_64BIT) { - // Fixed 64-bit: each element is 8 bytes - uint8_t const* p = packed_start; - while (p + 8 <= packed_end) { - int32_t elem_offset = static_cast(p - bytes - start); - occurrences[write_idx] = {static_cast(row), elem_offset, 8}; - write_idx++; - p += 8; - } - } - } else if (wt == target_wt) { - // Non-packed encoding: single element - int32_t data_offset, data_length; - if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { - set_error_once(error_flag, ERR_FIELD_SIZE); - return; - } - - int32_t abs_offset = static_cast(cur - bytes - start) + data_offset; - occurrences[write_idx] = {static_cast(row), abs_offset, data_length}; - write_idx++; - } - } - - // Skip to next field - uint8_t const* next; - if (!skip_field(cur, msg_end, wt, next)) { - set_error_once(error_flag, ERR_SKIP); - return; - } - cur = next; - } -} - -// ============================================================================ -// Pass 2: Extract data kernels -// ============================================================================ - -// ============================================================================ -// Data Extraction Location Providers -// ============================================================================ - -struct TopLevelLocationProvider { - cudf::size_type 
const* offsets; - cudf::size_type base_offset; - field_location const* locations; - int field_idx; - int num_fields; - - __device__ inline field_location get(int thread_idx, int32_t& data_offset) const - { - auto loc = locations[thread_idx * num_fields + field_idx]; - if (loc.offset >= 0) { data_offset = offsets[thread_idx] - base_offset + loc.offset; } - return loc; - } -}; - -struct RepeatedLocationProvider { - cudf::size_type const* row_offsets; - cudf::size_type base_offset; - repeated_occurrence const* occurrences; - - __device__ inline field_location get(int thread_idx, int32_t& data_offset) const - { - auto occ = occurrences[thread_idx]; - data_offset = row_offsets[occ.row_idx] - base_offset + occ.offset; - return {occ.offset, occ.length}; - } -}; - -struct NestedLocationProvider { - cudf::size_type const* row_offsets; - cudf::size_type base_offset; - field_location const* parent_locations; - field_location const* child_locations; - int field_idx; - int num_fields; - - __device__ inline field_location get(int thread_idx, int32_t& data_offset) const - { - auto ploc = parent_locations[thread_idx]; - auto cloc = child_locations[thread_idx * num_fields + field_idx]; - if (ploc.offset >= 0 && cloc.offset >= 0) { - data_offset = row_offsets[thread_idx] - base_offset + ploc.offset + cloc.offset; - } else { - cloc.offset = -1; - } - return cloc; - } -}; - -struct NestedRepeatedLocationProvider { - cudf::size_type const* row_offsets; - cudf::size_type base_offset; - field_location const* parent_locations; - repeated_occurrence const* occurrences; - - __device__ inline field_location get(int thread_idx, int32_t& data_offset) const - { - auto occ = occurrences[thread_idx]; - auto ploc = parent_locations[occ.row_idx]; - data_offset = row_offsets[occ.row_idx] - base_offset + ploc.offset + occ.offset; - return {occ.offset, occ.length}; - } -}; - -struct RepeatedMsgChildLocationProvider { - cudf::size_type const* row_offsets; - cudf::size_type base_offset; - field_location 
const* msg_locations; - field_location const* child_locations; - int field_idx; - int num_fields; - - __device__ inline field_location get(int thread_idx, int32_t& data_offset) const - { - auto mloc = msg_locations[thread_idx]; - auto cloc = child_locations[thread_idx * num_fields + field_idx]; - if (mloc.offset >= 0 && cloc.offset >= 0) { - data_offset = row_offsets[thread_idx] - base_offset + mloc.offset + cloc.offset; - } else { - cloc.offset = -1; - } - return cloc; - } -}; - -template -__global__ void extract_varint_kernel(uint8_t const* message_data, - LocationProvider loc_provider, - int total_items, - OutputType* out, - bool* valid, - int* error_flag, - bool has_default = false, - int64_t default_value = 0) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_items) return; - - int32_t data_offset = 0; - auto loc = loc_provider.get(idx, data_offset); - - if (loc.offset < 0) { - if (has_default) { - out[idx] = static_cast(default_value); - if (valid) valid[idx] = true; - } else { - if (valid) valid[idx] = false; - } - return; - } - - uint8_t const* cur = message_data + data_offset; - uint8_t const* cur_end = cur + loc.length; - - uint64_t v; - int n; - if (!read_varint(cur, cur_end, v, n)) { - set_error_once(error_flag, ERR_VARINT); - if (valid) valid[idx] = false; - return; - } - - if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } - out[idx] = static_cast(v); - if (valid) valid[idx] = true; -} - -template -__global__ void extract_fixed_kernel(uint8_t const* message_data, - LocationProvider loc_provider, - int total_items, - OutputType* out, - bool* valid, - int* error_flag, - bool has_default = false, - OutputType default_value = OutputType{}) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_items) return; - - int32_t data_offset = 0; - auto loc = loc_provider.get(idx, data_offset); - - if (loc.offset < 0) { - if (has_default) { - out[idx] = default_value; - if (valid) valid[idx] = 
true; - } else { - if (valid) valid[idx] = false; - } - return; - } - - uint8_t const* cur = message_data + data_offset; - OutputType value; - - if constexpr (WT == WT_32BIT) { - if (loc.length < 4) { - set_error_once(error_flag, ERR_FIXED_LEN); - if (valid) valid[idx] = false; - return; - } - uint32_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } else { - if (loc.length < 8) { - set_error_once(error_flag, ERR_FIXED_LEN); - if (valid) valid[idx] = false; - return; - } - uint64_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } - - out[idx] = value; - if (valid) valid[idx] = true; -} - -template -__global__ void extract_lengths_kernel(LocationProvider loc_provider, - int total_items, - int32_t* out_lengths, - bool has_default = false, - int32_t default_length = 0) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_items) return; - - int32_t data_offset = 0; - auto loc = loc_provider.get(idx, data_offset); - - if (loc.offset >= 0) { - out_lengths[idx] = loc.length; - } else if (has_default) { - out_lengths[idx] = default_length; - } else { - out_lengths[idx] = 0; - } -} -template -__global__ void copy_varlen_data_kernel(uint8_t const* message_data, - LocationProvider loc_provider, - int total_items, - cudf::size_type const* output_offsets, - char* output_chars, - int* error_flag, - bool has_default = false, - uint8_t const* default_chars = nullptr, - int default_len = 0) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_items) return; - - int32_t data_offset = 0; - auto loc = loc_provider.get(idx, data_offset); - - auto out_start = output_offsets[idx]; - - if (loc.offset < 0) { - if (has_default && default_len > 0) { - memcpy(output_chars + out_start, default_chars, default_len); - } - return; - } - - uint8_t const* src = message_data + data_offset; - memcpy(output_chars + out_start, src, loc.length); -} - -/** - * Extract varint field data using pre-recorded 
locations. - * Supports default values for missing fields. - */ -template -__global__ void extract_varint_from_locations_kernel( - uint8_t const* message_data, - cudf::size_type const* offsets, // List offsets for each row - cudf::size_type base_offset, - field_location const* locations, // [num_rows * num_fields] - int field_idx, - int num_fields, - OutputType* out, - bool* valid, - int num_rows, - int* error_flag, - bool has_default = false, - int64_t default_value = 0) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - auto loc = locations[row * num_fields + field_idx]; - if (loc.offset < 0) { - // Field not found - use default value if available - if (has_default) { - out[row] = static_cast(default_value); - valid[row] = true; - } else { - valid[row] = false; - } - return; - } - - // Calculate absolute offset in the message data - auto row_start = offsets[row] - base_offset; - uint8_t const* cur = message_data + row_start + loc.offset; - uint8_t const* cur_end = cur + loc.length; - - uint64_t v; - int n; - if (!read_varint(cur, cur_end, v, n)) { - set_error_once(error_flag, ERR_VARINT); - valid[row] = false; - return; - } - - if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } - out[row] = static_cast(v); - valid[row] = true; -} - -/** - * Extract fixed-size field data (fixed32, fixed64, float, double). - * Supports default values for missing fields. 
- */ -template -__global__ void extract_fixed_from_locations_kernel(uint8_t const* message_data, - cudf::size_type const* offsets, - cudf::size_type base_offset, - field_location const* locations, - int field_idx, - int num_fields, - OutputType* out, - bool* valid, - int num_rows, - int* error_flag, - bool has_default = false, - OutputType default_value = OutputType{}) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - auto loc = locations[row * num_fields + field_idx]; - if (loc.offset < 0) { - // Field not found - use default value if available - if (has_default) { - out[row] = default_value; - valid[row] = true; - } else { - valid[row] = false; - } - return; - } - - auto row_start = offsets[row] - base_offset; - uint8_t const* cur = message_data + row_start + loc.offset; - - OutputType value; - if constexpr (WT == WT_32BIT) { - if (loc.length < 4) { - set_error_once(error_flag, ERR_FIXED_LEN); - valid[row] = false; - return; - } - uint32_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } else { - if (loc.length < 8) { - set_error_once(error_flag, ERR_FIXED_LEN); - valid[row] = false; - return; - } - uint64_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } - - out[row] = value; - valid[row] = true; -} - -// ============================================================================ -// Repeated field extraction kernels -// ============================================================================ - -/** - * Extract repeated varint values using pre-recorded occurrences. 
- */ -template -__global__ void extract_repeated_varint_kernel(uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - repeated_occurrence const* occurrences, - int total_occurrences, - OutputType* out, - int* error_flag) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_occurrences) return; - - auto const& occ = occurrences[idx]; - auto row_start = row_offsets[occ.row_idx] - base_offset; - uint8_t const* cur = message_data + row_start + occ.offset; - uint8_t const* cur_end = cur + occ.length; - - uint64_t v; - int n; - if (!read_varint(cur, cur_end, v, n)) { - set_error_once(error_flag, ERR_VARINT); - out[idx] = OutputType{}; - return; - } - - if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } - out[idx] = static_cast(v); -} - -/** - * Extract repeated fixed-size values using pre-recorded occurrences. - */ -template -__global__ void extract_repeated_fixed_kernel(uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - repeated_occurrence const* occurrences, - int total_occurrences, - OutputType* out, - int* error_flag) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_occurrences) return; - - auto const& occ = occurrences[idx]; - auto row_start = row_offsets[occ.row_idx] - base_offset; - uint8_t const* cur = message_data + row_start + occ.offset; - - OutputType value; - if constexpr (WT == WT_32BIT) { - if (occ.length < 4) { - set_error_once(error_flag, ERR_FIXED_LEN); - out[idx] = OutputType{}; - return; - } - uint32_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } else { - if (occ.length < 8) { - set_error_once(error_flag, ERR_FIXED_LEN); - out[idx] = OutputType{}; - return; - } - uint64_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } - - out[idx] = value; -} - -// ============================================================================ -// Nested message scanning 
kernels -// ============================================================================ - -/** - * Scan nested message fields. - * Each row represents a nested message at a specific parent location. - * This kernel finds fields within the nested message bytes. - */ -__global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, - cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, - field_location const* parent_locations, - int num_parent_rows, - field_descriptor const* field_descs, - int num_fields, - field_location* output_locations, - int* error_flag) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_parent_rows) return; - - for (int f = 0; f < num_fields; f++) { - output_locations[row * num_fields + f] = {-1, 0}; - } - - auto const& parent_loc = parent_locations[row]; - if (parent_loc.offset < 0) { return; } - - auto parent_row_start = parent_row_offsets[row] - parent_base_offset; - uint8_t const* nested_start = message_data + parent_row_start + parent_loc.offset; - uint8_t const* nested_end = nested_start + parent_loc.length; - - uint8_t const* cur = nested_start; - - while (cur < nested_end) { - proto_tag tag; - if (!decode_tag(cur, nested_end, tag, error_flag)) { return; } - int fn = tag.field_number; - int wt = tag.wire_type; - - for (int f = 0; f < num_fields; f++) { - if (field_descs[f].field_number == fn) { - if (wt != field_descs[f].expected_wire_type) { - set_error_once(error_flag, ERR_WIRE_TYPE); - return; - } - - int data_offset = static_cast(cur - nested_start); - - if (wt == WT_LEN) { - uint64_t len; - int len_bytes; - if (!read_varint(cur, nested_end, len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - if (len > static_cast(nested_end - cur - len_bytes) || - len > static_cast(INT_MAX)) { - set_error_once(error_flag, ERR_OVERFLOW); - return; - } - output_locations[row * num_fields + f] = {data_offset + len_bytes, - static_cast(len)}; - } else { 
- int field_size = get_wire_type_size(wt, cur, nested_end); - if (field_size < 0) { - set_error_once(error_flag, ERR_FIELD_SIZE); - return; - } - output_locations[row * num_fields + f] = {data_offset, field_size}; - } - } - } - - uint8_t const* next; - if (!skip_field(cur, nested_end, wt, next)) { - set_error_once(error_flag, ERR_SKIP); - return; - } - cur = next; - } -} - -/** - * Build a null bitmask from a boolean validity array. - * @param valid Device vector where valid[i] indicates row i validity. - * @return Pair of (null mask buffer, null count). - */ -template -inline std::pair make_null_mask_from_valid( - rmm::device_uvector const& valid, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - auto begin = thrust::make_counting_iterator(0); - auto end = begin + valid.size(); - auto pred = [ptr = valid.data()] __device__(cudf::size_type i) { - return static_cast(ptr[i]); - }; - return cudf::detail::valid_if(begin, end, pred, stream, mr); -} - -inline void build_offsets_from_lengths(rmm::device_uvector const& lengths, - rmm::device_uvector& offsets, - rmm::cuda_stream_view stream) -{ - CUDF_EXPECTS(offsets.size() == lengths.size() + 1, "offsets size must equal lengths size + 1"); - CUDF_CUDA_TRY(cudaMemsetAsync(offsets.data(), 0, sizeof(int32_t), stream.value())); - if (lengths.size() > 0) { - thrust::inclusive_scan( - rmm::exec_policy(stream), lengths.begin(), lengths.end(), offsets.begin() + 1); - } -} - -template -std::unique_ptr extract_and_build_scalar_column(cudf::data_type dt, - int num_rows, - LaunchFn&& launch_extract, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? 
num_rows : 1), stream, mr); - launch_extract(out.data(), valid.data()); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - return std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count); -} - -template -// Shared integer extractor for INT32/INT64/UINT32/UINT64 decode paths. -inline void extract_integer_into_buffers(uint8_t const* message_data, - LocationProvider const& loc_provider, - int num_rows, - int blocks, - int threads, - bool has_default, - int64_t default_value, - int encoding, - bool enable_zigzag, - T* out_ptr, - bool* valid_ptr, - int* error_ptr, - rmm::cuda_stream_view stream) -{ - if (enable_zigzag && encoding == spark_rapids_jni::ENC_ZIGZAG) { - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - error_ptr, - has_default, - default_value); - } else if (encoding == spark_rapids_jni::ENC_FIXED) { - if constexpr (sizeof(T) == 4) { - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - error_ptr, - has_default, - static_cast(default_value)); - } else { - static_assert(sizeof(T) == 8, "extract_integer_into_buffers only supports 32/64-bit"); - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - error_ptr, - has_default, - static_cast(default_value)); - } - } else { - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - error_ptr, - has_default, - default_value); - } -} - -template -// Builds a scalar column for integer-like protobuf fields. 
-std::unique_ptr extract_and_build_integer_column(cudf::data_type dt, - uint8_t const* message_data, - LocationProvider const& loc_provider, - int num_rows, - int blocks, - int threads, - rmm::device_uvector& d_error, - bool has_default, - int64_t default_value, - int encoding, - bool enable_zigzag, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - return extract_and_build_scalar_column( - dt, - num_rows, - [&](T* out_ptr, bool* valid_ptr) { - extract_integer_into_buffers(message_data, - loc_provider, - num_rows, - blocks, - threads, - has_default, - default_value, - encoding, - enable_zigzag, - out_ptr, - valid_ptr, - d_error.data(), - stream); - }, - stream, - mr); -} - -/** - * Scan for child fields within repeated message occurrences. - * Each occurrence is a protobuf message, and we need to find child field locations within it. - */ -__global__ void scan_repeated_message_children_kernel( - uint8_t const* message_data, - int32_t const* msg_row_offsets, // Row offset for each occurrence - field_location const* - msg_locs, // Location of each message occurrence (offset within row, length) - int num_occurrences, - field_descriptor const* child_descs, - int num_child_fields, - field_location* child_locs, // Output: [num_occurrences * num_child_fields] - int* error_flag) -{ - auto occ_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (occ_idx >= num_occurrences) return; - - // Initialize child locations to not found - for (int f = 0; f < num_child_fields; f++) { - child_locs[occ_idx * num_child_fields + f] = {-1, 0}; - } - - auto const& msg_loc = msg_locs[occ_idx]; - if (msg_loc.offset < 0) return; - - // Calculate absolute position of this message in the data - int32_t row_offset = msg_row_offsets[occ_idx]; - uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; - uint8_t const* msg_end = msg_start + msg_loc.length; - - uint8_t const* cur = msg_start; - - while (cur < msg_end) { - proto_tag tag; - if 
(!decode_tag(cur, msg_end, tag, error_flag)) { return; } - int fn = tag.field_number; - int wt = tag.wire_type; - - // Check against child field descriptors - for (int f = 0; f < num_child_fields; f++) { - if (child_descs[f].field_number == fn) { - bool is_packed = (wt == WT_LEN && child_descs[f].expected_wire_type != WT_LEN); - if (!is_packed && wt != child_descs[f].expected_wire_type) { - set_error_once(error_flag, ERR_WIRE_TYPE); - return; - } - - int data_offset = static_cast(cur - msg_start); - - if (wt == WT_LEN) { - uint64_t len; - int len_bytes; - if (!read_varint(cur, msg_end, len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - // Store offset (after length prefix) and length - child_locs[occ_idx * num_child_fields + f] = {data_offset + len_bytes, - static_cast(len)}; - } else { - // For varint/fixed types, store offset and estimated length - int32_t data_length = 0; - if (wt == WT_VARINT) { - uint64_t dummy; - int vbytes; - if (read_varint(cur, msg_end, dummy, vbytes)) { data_length = vbytes; } - } else if (wt == WT_32BIT) { - data_length = 4; - } else if (wt == WT_64BIT) { - data_length = 8; - } - child_locs[occ_idx * num_child_fields + f] = {data_offset, data_length}; - } - // Don't break - last occurrence wins (protobuf semantics) - } - } - - // Skip to next field - uint8_t const* next; - if (!skip_field(cur, msg_end, wt, next)) { - set_error_once(error_flag, ERR_SKIP); - return; - } - cur = next; - } -} - -/** - * Count repeated field occurrences within nested messages. - * Similar to count_repeated_fields_kernel but operates on nested message locations. 
- */ -__global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_rows, - device_nested_field_descriptor const* schema, - int num_fields, - repeated_field_info* repeated_info, - int num_repeated, - int const* repeated_indices, - int* error_flag) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - // Initialize counts - for (int ri = 0; ri < num_repeated; ri++) { - repeated_info[row * num_repeated + ri] = {0, 0}; - } - - auto const& parent_loc = parent_locs[row]; - if (parent_loc.offset < 0) return; - - cudf::size_type row_off; - row_off = row_offsets[row] - base_offset; - - uint8_t const* msg_start = message_data + row_off + parent_loc.offset; - uint8_t const* msg_end = msg_start + parent_loc.length; - uint8_t const* cur = msg_start; - - while (cur < msg_end) { - proto_tag tag; - if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } - int fn = tag.field_number; - int wt = tag.wire_type; - - // Check if this is one of our repeated fields - for (int ri = 0; ri < num_repeated; ri++) { - int schema_idx = repeated_indices[ri]; - if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { - int expected_wt = schema[schema_idx].wire_type; - bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); - - if (!is_packed && wt != expected_wt) { - set_error_once(error_flag, ERR_WIRE_TYPE); - return; - } - - if (is_packed) { - uint64_t packed_len; - int len_bytes; - if (!read_varint(cur, msg_end, packed_len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - uint8_t const* packed_start = cur + len_bytes; - uint8_t const* packed_end = packed_start + packed_len; - if (packed_end > msg_end) { - set_error_once(error_flag, ERR_OVERFLOW); - return; - } - - int count = 0; - if (expected_wt == WT_VARINT) { - uint8_t const* p = packed_start; - while (p < 
packed_end) { - uint64_t dummy; - int vbytes; - if (!read_varint(p, packed_end, dummy, vbytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - p += vbytes; - count++; - } - } else if (expected_wt == WT_32BIT) { - if ((packed_len % 4) != 0) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - count = static_cast(packed_len / 4); - } else if (expected_wt == WT_64BIT) { - if ((packed_len % 8) != 0) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - count = static_cast(packed_len / 8); - } - repeated_info[row * num_repeated + ri].count += count; - repeated_info[row * num_repeated + ri].total_length += static_cast(packed_len); - } else { - int32_t data_offset, data_len; - if (!get_field_data_location(cur, msg_end, wt, data_offset, data_len)) { - set_error_once(error_flag, ERR_FIELD_SIZE); - return; - } - repeated_info[row * num_repeated + ri].count++; - repeated_info[row * num_repeated + ri].total_length += data_len; - } - } - } - - uint8_t const* next; - if (!skip_field(cur, msg_end, wt, next)) { - set_error_once(error_flag, ERR_SKIP); - return; - } - cur = next; - } -} - -/** - * Scan for repeated field occurrences within nested messages. - */ -__global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_rows, - device_nested_field_descriptor const* schema, - int num_fields, - int32_t const* occ_prefix_sums, - int num_repeated, - int const* repeated_indices, - repeated_occurrence* occurrences, - int* error_flag) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - auto const& parent_loc = parent_locs[row]; - if (parent_loc.offset < 0) return; - - // Prefix sum gives the write start offset for this row. 
- int occ_offset = occ_prefix_sums[row]; - - cudf::size_type row_off = row_offsets[row] - base_offset; - - uint8_t const* msg_start = message_data + row_off + parent_loc.offset; - uint8_t const* msg_end = msg_start + parent_loc.length; - uint8_t const* cur = msg_start; - - int occ_idx = 0; - - while (cur < msg_end) { - proto_tag tag; - if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } - int fn = tag.field_number; - int wt = tag.wire_type; - - // Check if this is one of our repeated fields. - for (int ri = 0; ri < num_repeated; ri++) { - int schema_idx = repeated_indices[ri]; - if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { - int expected_wt = schema[schema_idx].wire_type; - bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); - - if (!is_packed && wt != expected_wt) { - set_error_once(error_flag, ERR_WIRE_TYPE); - return; - } - - if (is_packed) { - uint64_t packed_len; - int len_bytes; - if (!read_varint(cur, msg_end, packed_len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - uint8_t const* packed_start = cur + len_bytes; - uint8_t const* packed_end = packed_start + packed_len; - if (packed_end > msg_end) { - set_error_once(error_flag, ERR_OVERFLOW); - return; - } - - if (expected_wt == WT_VARINT) { - uint8_t const* p = packed_start; - while (p < packed_end) { - int32_t elem_offset = static_cast(p - msg_start); - uint64_t dummy; - int vbytes; - if (!read_varint(p, packed_end, dummy, vbytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - occurrences[occ_offset + occ_idx] = {row, elem_offset, vbytes}; - occ_idx++; - p += vbytes; - } - } else if (expected_wt == WT_32BIT) { - if ((packed_len % 4) != 0) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - for (uint64_t i = 0; i < packed_len; i += 4) { - occurrences[occ_offset + occ_idx] = { - row, static_cast(packed_start - msg_start + i), 4}; - occ_idx++; - } - } else if (expected_wt == WT_64BIT) { - if ((packed_len % 8) != 
0) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - for (uint64_t i = 0; i < packed_len; i += 8) { - occurrences[occ_offset + occ_idx] = { - row, static_cast(packed_start - msg_start + i), 8}; - occ_idx++; - } - } - } else { - int32_t data_offset = static_cast(cur - msg_start); - int32_t data_len = 0; - if (wt == WT_LEN) { - uint64_t len; - int len_bytes; - if (!read_varint(cur, msg_end, len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - data_offset += len_bytes; - data_len = static_cast(len); - } else if (wt == WT_VARINT) { - uint64_t dummy; - int vbytes; - if (read_varint(cur, msg_end, dummy, vbytes)) { data_len = vbytes; } - } else if (wt == WT_32BIT) { - data_len = 4; - } else if (wt == WT_64BIT) { - data_len = 8; - } - - occurrences[occ_offset + occ_idx] = {row, data_offset, data_len}; - occ_idx++; - } - } - } - - uint8_t const* next; - if (!skip_field(cur, msg_end, wt, next)) { - set_error_once(error_flag, ERR_SKIP); - return; - } - cur = next; - } -} - -/** - * Extract varint child fields from repeated message occurrences. 
- */ -template -__global__ void extract_repeated_msg_child_varint_kernel(uint8_t const* message_data, - int32_t const* msg_row_offsets, - field_location const* msg_locs, - field_location const* child_locs, - int child_idx, - int num_child_fields, - OutputType* out, - bool* valid, - int num_occurrences, - int* error_flag, - bool has_default = false, - int64_t default_value = 0) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= num_occurrences) return; - - auto const& msg_loc = msg_locs[idx]; - auto const& field_loc = child_locs[idx * num_child_fields + child_idx]; - - if (msg_loc.offset < 0 || field_loc.offset < 0) { - if (has_default) { - out[idx] = static_cast(default_value); - valid[idx] = true; - } else { - valid[idx] = false; - } - return; - } - - int32_t row_offset = msg_row_offsets[idx]; - uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; - uint8_t const* cur = msg_start + field_loc.offset; - uint8_t const* msg_end = msg_start + msg_loc.length; - uint8_t const* varint_end = - (cur + MAX_VARINT_BYTES < msg_end) ? (cur + MAX_VARINT_BYTES) : msg_end; - - uint64_t val; - int vbytes; - if (!read_varint(cur, varint_end, val, vbytes)) { - set_error_once(error_flag, ERR_VARINT); - valid[idx] = false; - return; - } - - if constexpr (ZigZag) { val = (val >> 1) ^ (-(val & 1)); } - - out[idx] = static_cast(val); - valid[idx] = true; -} - -/** - * Extract fixed-size child fields from repeated message occurrences. 
- */ -template -__global__ void extract_repeated_msg_child_fixed_kernel(uint8_t const* message_data, - int32_t const* msg_row_offsets, - field_location const* msg_locs, - field_location const* child_locs, - int child_idx, - int num_child_fields, - OutputType* out, - bool* valid, - int num_occurrences, - int* error_flag, - bool has_default = false, - OutputType default_value = OutputType{}) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= num_occurrences) return; - - auto const& msg_loc = msg_locs[idx]; - auto const& field_loc = child_locs[idx * num_child_fields + child_idx]; - - if (msg_loc.offset < 0 || field_loc.offset < 0) { - if (has_default) { - out[idx] = default_value; - valid[idx] = true; - } else { - valid[idx] = false; - } - return; - } - - int32_t row_offset = msg_row_offsets[idx]; - uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; - uint8_t const* cur = msg_start + field_loc.offset; - - OutputType value; - if constexpr (WT == WT_32BIT) { - uint32_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } else { - uint64_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } - - out[idx] = value; - valid[idx] = true; -} - -/** - * Kernel to extract string data from repeated message child fields. - * Copies all strings in parallel on the GPU instead of per-string host copies. - */ - -/** - * Helper to build string column for repeated message child fields. - * Uses GPU kernels for parallel string extraction (critical performance fix!). 
- */ -inline std::unique_ptr build_repeated_msg_child_string_column( - uint8_t const* message_data, - rmm::device_uvector const& d_msg_row_offsets, - rmm::device_uvector const& d_msg_locs, - rmm::device_uvector const& d_child_locs, - int child_idx, - int num_child_fields, - int total_count, - rmm::device_uvector& d_error, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - if (total_count == 0) { return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); } - - auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; - - // Compute string lengths on GPU using child_locs directly - rmm::device_uvector d_lengths(total_count, stream, mr); - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_count), - d_lengths.data(), - [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { - auto const& loc = child_locs[idx * ncf + ci]; - return loc.offset >= 0 ? loc.length : 0; - }); - - // Compute offsets without host round-trip - rmm::device_uvector d_str_offsets(total_count + 1, stream, mr); - build_offsets_from_lengths(d_lengths, d_str_offsets, stream); - int32_t total_chars = - thrust::reduce(rmm::exec_policy(stream), d_lengths.begin(), d_lengths.end(), 0); - - // Allocate output chars and validity - rmm::device_uvector d_chars(total_chars, stream, mr); - rmm::device_uvector d_valid((total_count > 0 ? 
total_count : 1), stream, mr); - - // Set validity for all entries - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_count), - d_valid.data(), - [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { - return child_locs[idx * ncf + ci].offset >= 0; - }); - - // Extract all strings in parallel on GPU - if (total_chars > 0) { - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - child_idx, - num_child_fields}; - copy_varlen_data_kernel - <<>>(message_data, - loc_provider, - total_count, - d_str_offsets.data(), - d_chars.data(), - d_error.data()); - } - - auto [mask, null_count] = make_null_mask_from_valid(d_valid, stream, mr); - - auto str_offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_count + 1, - d_str_offsets.release(), - rmm::device_buffer{}, - 0); - return cudf::make_strings_column( - total_count, std::move(str_offsets_col), d_chars.release(), null_count, std::move(mask)); -} - -inline std::unique_ptr build_repeated_msg_child_bytes_column( - uint8_t const* message_data, - rmm::device_uvector const& d_msg_row_offsets, - rmm::device_uvector const& d_msg_locs, - rmm::device_uvector const& d_child_locs, - int child_idx, - int num_child_fields, - int total_count, - rmm::device_uvector& d_error, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - if (total_count == 0) { - auto empty_offsets = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - 1, - rmm::device_buffer(sizeof(int32_t), stream, mr), - rmm::device_buffer{}, - 0); - int32_t zero = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(empty_offsets->mutable_view().data(), - &zero, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - auto empty_bytes = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); - return 
cudf::make_lists_column( - 0, std::move(empty_offsets), std::move(empty_bytes), 0, rmm::device_buffer{}, stream, mr); - } - - auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; - - rmm::device_uvector d_lengths(total_count, stream, mr); - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_count), - d_lengths.data(), - [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { - auto const& loc = child_locs[idx * ncf + ci]; - return loc.offset >= 0 ? loc.length : 0; - }); - - rmm::device_uvector d_offs(total_count + 1, stream, mr); - build_offsets_from_lengths(d_lengths, d_offs, stream); - int32_t total_bytes = - thrust::reduce(rmm::exec_policy(stream), d_lengths.begin(), d_lengths.end(), 0); - - rmm::device_uvector d_bytes(total_bytes, stream, mr); - rmm::device_uvector d_valid((total_count > 0 ? total_count : 1), stream, mr); - - // Set validity for all entries - thrust::transform( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_count), - d_valid.data(), - [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { - return child_locs[idx * ncf + ci].offset >= 0; - }); - - if (total_bytes > 0) { - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - child_idx, - num_child_fields}; - copy_varlen_data_kernel - <<>>( - message_data, loc_provider, total_count, d_offs.data(), d_bytes.data(), d_error.data()); - } - - auto [mask, null_count] = make_null_mask_from_valid(d_valid, stream, mr); - auto offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_count + 1, - d_offs.release(), - rmm::device_buffer{}, - 0); - auto bytes_child = - std::make_unique(cudf::data_type{cudf::type_id::UINT8}, - total_bytes, - 
rmm::device_buffer(d_bytes.data(), total_bytes, stream, mr), - rmm::device_buffer{}, - 0); - return cudf::make_lists_column(total_count, - std::move(offs_col), - std::move(bytes_child), - null_count, - std::move(mask), - stream, - mr); -} - -/** - * Kernel to compute nested struct locations from child field locations. - * Replaces host-side loop that was copying data D->H, processing, then H->D. - * This is a critical performance optimization. - */ -__global__ void compute_nested_struct_locations_kernel( - field_location const* child_locs, // Child field locations from parent scan - field_location const* msg_locs, // Parent message locations - int32_t const* msg_row_offsets, // Parent message row offsets - int child_idx, // Which child field is the nested struct - int num_child_fields, // Total number of child fields per occurrence - field_location* nested_locs, // Output: nested struct locations - int32_t* nested_row_offsets, // Output: nested struct row offsets - int total_count) -{ - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_count) return; - - // Get the nested struct location from child_locs - nested_locs[idx] = child_locs[idx * num_child_fields + child_idx]; - // Compute absolute row offset = msg_row_offset + msg_offset - nested_row_offsets[idx] = msg_row_offsets[idx] + msg_locs[idx].offset; -} - -/** - * Kernel to compute absolute grandchild parent locations from parent and child locations. - * Computes: gc_parent_abs[i] = parent[i].offset + child[i * ncf + ci].offset - * This replaces host-side loop with D->H->D copy pattern. 
- */ -__global__ void compute_grandchild_parent_locations_kernel( - field_location const* parent_locs, // Parent locations (row count) - field_location const* child_locs, // Child locations (row * num_child_fields) - int child_idx, // Which child field - int num_child_fields, // Total child fields per row - field_location* gc_parent_abs, // Output: absolute grandchild parent locations - int num_rows) -{ - int row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= num_rows) return; - - auto const& parent_loc = parent_locs[row]; - auto const& child_loc = child_locs[row * num_child_fields + child_idx]; - - if (parent_loc.offset >= 0 && child_loc.offset >= 0) { - // Absolute offset = parent offset + child's relative offset - gc_parent_abs[row].offset = parent_loc.offset + child_loc.offset; - gc_parent_abs[row].length = child_loc.length; - } else { - gc_parent_abs[row] = {-1, 0}; - } -} - -/** - * Compute virtual parent row offsets and locations for repeated message occurrences - * inside nested messages. Each occurrence becomes a virtual "row" so that - * build_nested_struct_column can recursively process the children. - */ -__global__ void compute_virtual_parents_for_nested_repeated_kernel( - repeated_occurrence const* occurrences, - cudf::size_type const* row_list_offsets, // original binary input list offsets - field_location const* parent_locations, // parent nested message locations - cudf::size_type* virtual_row_offsets, // output: [total_count] - field_location* virtual_parent_locs, // output: [total_count] - int total_count) -{ - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_count) return; - - auto const& occ = occurrences[idx]; - auto const& ploc = parent_locations[occ.row_idx]; - - virtual_row_offsets[idx] = row_list_offsets[occ.row_idx]; - - // Keep zero-length embedded messages as "present but empty". 
- // Protobuf allows an embedded message with length=0, which maps to a non-null - // struct with all-null children (not a null struct). - if (ploc.offset >= 0) { - virtual_parent_locs[idx] = {ploc.offset + occ.offset, occ.length}; - } else { - virtual_parent_locs[idx] = {-1, 0}; - } -} - -/** - * Kernel to compute message locations and row offsets from repeated occurrences. - * Replaces host-side loop that processed occurrences. - */ -__global__ void compute_msg_locations_from_occurrences_kernel( - repeated_occurrence const* occurrences, // Repeated field occurrences - cudf::size_type const* list_offsets, // List offsets for rows - cudf::size_type base_offset, // Base offset to subtract - field_location* msg_locs, // Output: message locations - int32_t* msg_row_offsets, // Output: message row offsets - int total_count) -{ - int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_count) return; - - auto const& occ = occurrences[idx]; - msg_row_offsets[idx] = static_cast(list_offsets[occ.row_idx] - base_offset); - msg_locs[idx] = {occ.offset, occ.length}; -} - -/** - * Extract a single field's locations from a 2D strided array on the GPU. - * Replaces a D2H + CPU loop + H2D pattern for nested message location extraction. - */ -__global__ void extract_strided_locations_kernel(field_location const* nested_locations, - int field_idx, - int num_fields, - field_location* parent_locs, - int num_rows) -{ - int row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= num_rows) return; - parent_locs[row] = nested_locations[row * num_fields + field_idx]; -} - -/** - * Functor to extract count from repeated_field_info with strided access. - * Used for extracting counts for a specific repeated field from 2D array. 
- */ -struct extract_strided_count { - repeated_field_info const* info; - int field_idx; - int num_fields; - - __device__ int32_t operator()(int row) const { return info[row * num_fields + field_idx].count; } -}; - -/** - * Extract varint from nested message locations. - */ -template -__global__ void extract_nested_varint_kernel(uint8_t const* message_data, - cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, - field_location const* parent_locations, - field_location const* field_locations, - int field_idx, - int num_fields, - OutputType* out, - bool* valid, - int num_rows, - int* error_flag, - bool has_default = false, - int64_t default_value = 0) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - auto const& parent_loc = parent_locations[row]; - auto const& field_loc = field_locations[row * num_fields + field_idx]; - - if (parent_loc.offset < 0 || field_loc.offset < 0) { - if (has_default) { - out[row] = static_cast(default_value); - valid[row] = true; - } else { - valid[row] = false; - } - return; - } - - auto parent_row_start = parent_row_offsets[row] - parent_base_offset; - uint8_t const* cur = message_data + parent_row_start + parent_loc.offset + field_loc.offset; - uint8_t const* cur_end = cur + field_loc.length; - - uint64_t v; - int n; - if (!read_varint(cur, cur_end, v, n)) { - set_error_once(error_flag, ERR_VARINT); - valid[row] = false; - return; - } - - if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } - out[row] = static_cast(v); - valid[row] = true; -} - -/** - * Extract fixed-size from nested message locations. 
- */ -template -__global__ void extract_nested_fixed_kernel(uint8_t const* message_data, - cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, - field_location const* parent_locations, - field_location const* field_locations, - int field_idx, - int num_fields, - OutputType* out, - bool* valid, - int num_rows, - int* error_flag, - bool has_default = false, - OutputType default_value = OutputType{}) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - auto const& parent_loc = parent_locations[row]; - auto const& field_loc = field_locations[row * num_fields + field_idx]; - - if (parent_loc.offset < 0 || field_loc.offset < 0) { - if (has_default) { - out[row] = default_value; - valid[row] = true; - } else { - valid[row] = false; - } - return; - } - - auto parent_row_start = parent_row_offsets[row] - parent_base_offset; - uint8_t const* cur = message_data + parent_row_start + parent_loc.offset + field_loc.offset; - - OutputType value; - if constexpr (WT == WT_32BIT) { - if (field_loc.length < 4) { - set_error_once(error_flag, ERR_FIXED_LEN); - valid[row] = false; - return; - } - uint32_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } else { - if (field_loc.length < 8) { - set_error_once(error_flag, ERR_FIXED_LEN); - valid[row] = false; - return; - } - uint64_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } - - out[row] = value; - valid[row] = true; -} - -/** - * Copy nested variable-length data (string/bytes). - */ - -/** - * Copy scalar string field data. - * For top-level STRING fields (not nested within a struct). 
- */ - -// ============================================================================ -// Utility functions -// ============================================================================ - -// Note: make_null_mask_from_valid is defined earlier in the file (before -// scan_repeated_message_children_kernel) - -/** - * Create an all-null column of the specified type. - */ -std::unique_ptr make_null_column(cudf::data_type dtype, - cudf::size_type num_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - if (num_rows == 0) { return cudf::make_empty_column(dtype); } - - switch (dtype.id()) { - case cudf::type_id::BOOL8: - case cudf::type_id::INT8: - case cudf::type_id::UINT8: - case cudf::type_id::INT16: - case cudf::type_id::UINT16: - case cudf::type_id::INT32: - case cudf::type_id::UINT32: - case cudf::type_id::INT64: - case cudf::type_id::UINT64: - case cudf::type_id::FLOAT32: - case cudf::type_id::FLOAT64: { - auto data = rmm::device_buffer(cudf::size_of(dtype) * num_rows, stream, mr); - auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); - return std::make_unique( - dtype, num_rows, std::move(data), std::move(null_mask), num_rows); - } - case cudf::type_id::STRING: { - // Create empty strings column with all nulls - rmm::device_uvector pairs(num_rows, stream, mr); - thrust::fill(rmm::exec_policy(stream), - pairs.data(), - pairs.end(), - cudf::strings::detail::string_index_pair{nullptr, 0}); - return cudf::strings::detail::make_strings_column(pairs.data(), pairs.end(), stream, mr); - } - case cudf::type_id::LIST: { - // Create LIST with all nulls - // Offsets: all zeros (empty lists) - rmm::device_uvector offsets(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - offsets.release(), - rmm::device_buffer{}, - 0); - - // Empty child column - use UINT8 for 
BinaryType consistency - // This works because the list has 0 elements, so the child type doesn't matter for nulls - auto child_col = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); - - // All null mask - auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); - - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - num_rows, - std::move(null_mask), - stream, - mr); - } - case cudf::type_id::STRUCT: { - // TODO(protobuf): This creates an empty STRUCT with no children, which does not - // match the expected nested schema. This is a crash-prevention workaround for - // unprocessed struct fields at deep nesting levels. A proper fix would recurse - // into the schema to build the correct child column structure with all-null leaves. - std::vector> empty_children; - auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); - return cudf::make_structs_column( - num_rows, std::move(empty_children), num_rows, std::move(null_mask), stream, mr); - } - default: CUDF_FAIL("Unsupported type for null column creation"); - } -} - -/** - * Create an empty column (0 rows) of the specified type. - * This handles nested types (LIST, STRUCT) that cudf::make_empty_column doesn't support. 
- */ -std::unique_ptr make_empty_column_safe(cudf::data_type dtype, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - switch (dtype.id()) { - case cudf::type_id::LIST: { - // Create empty list column with empty UINT8 child (Spark BinaryType maps to LIST) - auto offsets_col = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - 1, - rmm::device_buffer(sizeof(int32_t), stream, mr), - rmm::device_buffer{}, - 0); - // Initialize offset to 0 - int32_t zero = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), - &zero, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - auto child_col = std::make_unique( - cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); - return cudf::make_lists_column( - 0, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); - } - case cudf::type_id::STRUCT: { - // Create empty struct column with no children - std::vector> empty_children; - return cudf::make_structs_column( - 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); - } - default: - // For non-nested types, use cudf's make_empty_column - return cudf::make_empty_column(dtype); - } -} - -/** - * Create an all-null LIST column with the provided child column. - * The child column is expected to have 0 rows. 
- */ -std::unique_ptr make_null_list_column_with_child( - std::unique_ptr child_col, - cudf::size_type num_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - rmm::device_uvector offsets(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - offsets.release(), - rmm::device_buffer{}, - 0); - auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - num_rows, - std::move(null_mask), - stream, - mr); -} - -/** - * Wrap a 0-row element column into a 0-row LIST column. - */ -std::unique_ptr make_empty_list_column(std::unique_ptr element_col, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - 1, - rmm::device_buffer(sizeof(int32_t), stream, mr), - rmm::device_buffer{}, - 0); - int32_t zero = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), - &zero, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - return cudf::make_lists_column( - 0, std::move(offsets_col), std::move(element_col), 0, rmm::device_buffer{}, stream, mr); -} - -/** - * Find all child field indices for a given parent index in the schema. - * This is a commonly used pattern throughout the codebase. 
- * - * @param schema The schema vector (either nested_field_descriptor or - * device_nested_field_descriptor) - * @param num_fields Number of fields in the schema - * @param parent_idx The parent index to search for - * @return Vector of child field indices - */ -template -std::vector find_child_field_indices(SchemaT const& schema, int num_fields, int parent_idx) -{ - std::vector child_indices; - for (int i = 0; i < num_fields; i++) { - if (schema[i].parent_idx == parent_idx) { child_indices.push_back(i); } - } - return child_indices; -} - -/** - * Recursively create an empty struct column with proper nested structure based on schema. - * This handles STRUCT children that contain their own grandchildren. - * - * @param schema The schema vector - * @param schema_output_types Output types for each schema field - * @param parent_idx Index of the parent field (whose children we want to create) - * @param num_fields Total number of fields in schema - * @param stream CUDA stream - * @param mr Memory resource - * @return Empty struct column with proper nested structure - */ -template -std::unique_ptr make_empty_struct_column_with_schema( - SchemaT const& schema, - std::vector const& schema_output_types, - int parent_idx, - int num_fields, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - auto child_indices = find_child_field_indices(schema, num_fields, parent_idx); - - std::vector> children; - for (int child_idx : child_indices) { - auto child_type = schema_output_types[child_idx]; - - std::unique_ptr child_col; - if (child_type.id() == cudf::type_id::STRUCT) { - child_col = make_empty_struct_column_with_schema( - schema, schema_output_types, child_idx, num_fields, stream, mr); - } else { - child_col = make_empty_column_safe(child_type, stream, mr); - } - - if (schema[child_idx].is_repeated) { - child_col = make_empty_list_column(std::move(child_col), stream, mr); - } - - children.push_back(std::move(child_col)); - } - - return 
cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer{}, stream, mr); -} - -} // namespace - -namespace { - -// ============================================================================ -// Kernel to check required fields after scan pass -// ============================================================================ - -/** - * Check if any required fields are missing (offset < 0) and set error flag. - * This is called after the scan pass to validate required field constraints. - */ -__global__ void check_required_fields_kernel( - field_location const* locations, // [num_rows * num_fields] - uint8_t const* is_required, // [num_fields] (1 = required, 0 = optional) - int num_fields, - int num_rows, - int* error_flag) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - for (int f = 0; f < num_fields; f++) { - if (is_required[f] != 0 && locations[row * num_fields + f].offset < 0) { - // Required field is missing - set error flag - set_error_once(error_flag, ERR_REQUIRED); - return; // No need to check other fields for this row - } - } -} - -/** - * Validate enum values against a set of valid values. - * If a value is not in the valid set: - * 1. Mark the field as invalid (valid[row] = false) - * 2. Mark the row as having an invalid enum (row_has_invalid_enum[row] = true) - * - * This matches Spark CPU PERMISSIVE mode behavior: when an unknown enum value is - * encountered, the entire struct row is set to null (not just the enum field). - * - * The valid_values array must be sorted for binary search. - * - * @note Time complexity: O(log(num_valid_values)) per row. 
- */ -__global__ void validate_enum_values_kernel( - int32_t const* values, // [num_rows] extracted enum values - bool* valid, // [num_rows] field validity flags (will be modified) - bool* row_has_invalid_enum, // [num_rows] row-level invalid enum flag (will be set to true) - int32_t const* valid_enum_values, // sorted array of valid enum values - int num_valid_values, // size of valid_enum_values - int num_rows) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - // Skip if already invalid (field was missing) - missing field is not an enum error - if (!valid[row]) return; - - int32_t val = values[row]; - - // Binary search for the value in valid_enum_values - int left = 0; - int right = num_valid_values - 1; - bool found = false; - - while (left <= right) { - int mid = left + (right - left) / 2; - if (valid_enum_values[mid] == val) { - found = true; - break; - } else if (valid_enum_values[mid] < val) { - left = mid + 1; - } else { - right = mid - 1; - } - } - - // If not found, mark as invalid - if (!found) { - valid[row] = false; - // Also mark the row as having an invalid enum - this will null the entire struct row - row_has_invalid_enum[row] = true; - } -} - -/** - * Compute output UTF-8 length for enum-as-string rows. - * Invalid/missing values produce length 0 (null row/field semantics handled by valid[] and - * row_has_invalid_enum). 
- */ -__global__ void compute_enum_string_lengths_kernel( - int32_t const* values, // [num_rows] enum numeric values - bool const* valid, // [num_rows] field validity - int32_t const* valid_enum_values, // sorted enum numeric values - int32_t const* enum_name_offsets, // [num_valid_values + 1] - int num_valid_values, - int32_t* lengths, // [num_rows] - int num_rows) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - - if (!valid[row]) { - lengths[row] = 0; - return; - } - - int32_t val = values[row]; - int left = 0; - int right = num_valid_values - 1; - while (left <= right) { - int mid = left + (right - left) / 2; - int32_t mid_val = valid_enum_values[mid]; - if (mid_val == val) { - lengths[row] = enum_name_offsets[mid + 1] - enum_name_offsets[mid]; - return; - } else if (mid_val < val) { - left = mid + 1; - } else { - right = mid - 1; - } - } - - // Should not happen when validate_enum_values_kernel has already run, but keep safe. - lengths[row] = 0; -} - -/** - * Copy enum-as-string UTF-8 bytes into output chars buffer using precomputed row offsets. 
- */ -__global__ void copy_enum_string_chars_kernel( - int32_t const* values, // [num_rows] enum numeric values - bool const* valid, // [num_rows] field validity - int32_t const* valid_enum_values, // sorted enum numeric values - int32_t const* enum_name_offsets, // [num_valid_values + 1] - uint8_t const* enum_name_chars, // concatenated enum UTF-8 names - int num_valid_values, - int32_t const* output_offsets, // [num_rows + 1] - char* out_chars, // [total_chars] - int num_rows) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - if (!valid[row]) return; - - int32_t val = values[row]; - int left = 0; - int right = num_valid_values - 1; - while (left <= right) { - int mid = left + (right - left) / 2; - int32_t mid_val = valid_enum_values[mid]; - if (mid_val == val) { - int32_t src_begin = enum_name_offsets[mid]; - int32_t src_end = enum_name_offsets[mid + 1]; - int32_t dst_begin = output_offsets[row]; - for (int32_t i = 0; i < (src_end - src_begin); ++i) { - out_chars[dst_begin + i] = static_cast(enum_name_chars[src_begin + i]); - } - return; - } else if (mid_val < val) { - left = mid + 1; - } else { - right = mid - 1; - } - } -} - -std::unique_ptr build_enum_string_column( - rmm::device_uvector& enum_values, - rmm::device_uvector& valid, - std::vector const& valid_enums, - std::vector> const& enum_name_bytes, - rmm::device_uvector& d_row_has_invalid_enum, - int num_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((num_rows + threads - 1) / threads); - - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - validate_enum_values_kernel<<>>( - enum_values.data(), - valid.data(), - d_row_has_invalid_enum.data(), - d_valid_enums.data(), - 
static_cast(valid_enums.size()), - num_rows); - - std::vector h_name_offsets(valid_enums.size() + 1, 0); - int32_t total_name_chars = 0; - for (size_t k = 0; k < enum_name_bytes.size(); ++k) { - total_name_chars += static_cast(enum_name_bytes[k].size()); - h_name_offsets[k + 1] = total_name_chars; - } - std::vector h_name_chars(total_name_chars); - int32_t cursor = 0; - for (auto const& name : enum_name_bytes) { - if (!name.empty()) { - std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + cursor); - cursor += static_cast(name.size()); - } - } - - rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), - h_name_offsets.data(), - h_name_offsets.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - rmm::device_uvector d_name_chars(total_name_chars, stream, mr); - if (total_name_chars > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), - h_name_chars.data(), - total_name_chars * sizeof(uint8_t), - cudaMemcpyHostToDevice, - stream.value())); - } - - rmm::device_uvector lengths(num_rows, stream, mr); - compute_enum_string_lengths_kernel<<>>( - enum_values.data(), - valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - static_cast(valid_enums.size()), - lengths.data(), - num_rows); - - rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - build_offsets_from_lengths(lengths, output_offsets, stream); - int32_t total_chars = thrust::reduce(rmm::exec_policy(stream), lengths.begin(), lengths.end(), 0); - - rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - copy_enum_string_chars_kernel<<>>( - enum_values.data(), - valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - d_name_chars.data(), - static_cast(valid_enums.size()), - output_offsets.data(), - chars.data(), - num_rows); - } - - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - auto offsets_col = 
std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - output_offsets.release(), - rmm::device_buffer{}, - 0); - return cudf::make_strings_column( - num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); -} - -template -std::unique_ptr extract_and_build_string_or_bytes_column( - bool as_bytes, - uint8_t const* message_data, - int num_rows, - LengthProvider const& length_provider, - CopyProvider const& copy_provider, - ValidityFn validity_fn, - bool has_default, - std::vector const& default_bytes, - rmm::device_uvector& d_error, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - int32_t def_len = has_default ? static_cast(default_bytes.size()) : 0; - rmm::device_uvector d_default(def_len, stream, mr); - if (has_default && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync( - d_default.data(), default_bytes.data(), def_len, cudaMemcpyHostToDevice, stream.value())); - } - - rmm::device_uvector lengths(num_rows, stream, mr); - auto const threads = THREADS_PER_BLOCK; - auto const blocks = (num_rows + threads - 1) / threads; - extract_lengths_kernel<<>>( - length_provider, num_rows, lengths.data(), has_default, def_len); - - rmm::device_uvector output_offsets(num_rows + 1, stream, mr); - build_offsets_from_lengths(lengths, output_offsets, stream); - int32_t total_size = thrust::reduce(rmm::exec_policy(stream), lengths.begin(), lengths.end(), 0); - - rmm::device_uvector chars(total_size, stream, mr); - if (total_size > 0) { - copy_varlen_data_kernel - <<>>(message_data, - copy_provider, - num_rows, - output_offsets.data(), - chars.data(), - d_error.data(), - has_default, - d_default.data(), - def_len); - } - - rmm::device_uvector valid((num_rows > 0 ? 
num_rows : 1), stream, mr); - thrust::transform(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - valid.data(), - validity_fn); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - output_offsets.release(), - rmm::device_buffer{}, - 0); - if (as_bytes) { - auto bytes_child = - std::make_unique(cudf::data_type{cudf::type_id::UINT8}, - total_size, - rmm::device_buffer(chars.data(), total_size, stream, mr), - rmm::device_buffer{}, - 0); - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(bytes_child), - null_count, - std::move(mask), - stream, - mr); - } - - return cudf::make_strings_column( - num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); -} - -} // namespace +using namespace spark_rapids_jni::protobuf_detail; namespace spark_rapids_jni { -namespace { - -template -std::unique_ptr extract_typed_column( - cudf::data_type dt, - int encoding, - uint8_t const* message_data, - LocationProvider const& loc_provider, - int num_items, - int blocks, - int threads_per_block, - bool has_default, - int64_t default_int, - double default_float, - bool default_bool, - std::vector const& default_string, - int schema_idx, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, - rmm::device_uvector& d_error, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - switch (dt.id()) { - case cudf::type_id::BOOL8: { - int64_t def_val = has_default ? (default_bool ? 
1 : 0) : 0; - return extract_and_build_scalar_column( - dt, - num_items, - [&](uint8_t* out_ptr, bool* valid_ptr) { - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_items, - out_ptr, - valid_ptr, - d_error.data(), - has_default, - def_val); - }, - stream, - mr); - } - case cudf::type_id::INT32: { - rmm::device_uvector out(num_items, stream, mr); - rmm::device_uvector valid((num_items > 0 ? num_items : 1), stream, mr); - extract_integer_into_buffers(message_data, - loc_provider, - num_items, - blocks, - threads_per_block, - has_default, - default_int, - encoding, - true, - out.data(), - valid.data(), - d_error.data(), - stream); - if (schema_idx < static_cast(enum_valid_values.size())) { - auto const& valid_enums = enum_valid_values[schema_idx]; - if (!valid_enums.empty()) { - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - validate_enum_values_kernel<<>>( - out.data(), - valid.data(), - d_row_has_invalid_enum.data(), - d_valid_enums.data(), - static_cast(valid_enums.size()), - num_items); - } - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - return std::make_unique( - dt, num_items, out.release(), std::move(mask), null_count); - } - case cudf::type_id::UINT32: - return extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_items, - blocks, - threads_per_block, - d_error, - has_default, - default_int, - encoding, - false, - stream, - mr); - case cudf::type_id::INT64: - return extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_items, - blocks, - threads_per_block, - d_error, - has_default, - default_int, - encoding, - true, - stream, - mr); - case cudf::type_id::UINT64: - return extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_items, - blocks, - threads_per_block, - 
d_error, - has_default, - default_int, - encoding, - false, - stream, - mr); - case cudf::type_id::FLOAT32: { - float def_float_val = has_default ? static_cast(default_float) : 0.0f; - return extract_and_build_scalar_column( - dt, - num_items, - [&](float* out_ptr, bool* valid_ptr) { - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_items, - out_ptr, - valid_ptr, - d_error.data(), - has_default, - def_float_val); - }, - stream, - mr); - } - case cudf::type_id::FLOAT64: { - double def_double = has_default ? default_float : 0.0; - return extract_and_build_scalar_column( - dt, - num_items, - [&](double* out_ptr, bool* valid_ptr) { - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_items, - out_ptr, - valid_ptr, - d_error.data(), - has_default, - def_double); - }, - stream, - mr); - } - default: return make_null_column(dt, num_items, stream, mr); - } -} - -/** - * Helper to build a repeated scalar column (LIST of scalar type). - */ -template -std::unique_ptr build_repeated_scalar_column( - cudf::column_view const& binary_input, - uint8_t const* message_data, - cudf::size_type const* list_offsets, - cudf::size_type base_offset, - device_nested_field_descriptor const& field_desc, - rmm::device_uvector const& d_field_counts, - rmm::device_uvector& d_occurrences, - int total_count, - int num_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - auto const input_null_count = binary_input.null_count(); - - if (total_count == 0) { - // All rows have count=0, but we still need to check input nulls - rmm::device_uvector offsets(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - offsets.release(), - rmm::device_buffer{}, - 0); - auto elem_type = field_desc.output_type_id == static_cast(cudf::type_id::LIST) - ? 
cudf::type_id::UINT8 - : static_cast(field_desc.output_type_id); - auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); - - if (input_null_count > 0) { - // Copy input null mask - only input nulls produce output nulls - auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - input_null_count, - std::move(null_mask), - stream, - mr); - } else { - // No input nulls, all rows get empty arrays [] - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - 0, - rmm::device_buffer{}, - stream, - mr); - } - } - - rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); - - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, - &total_count, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector values(total_count, stream, mr); - rmm::device_uvector d_error(1, stream, mr); - CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); - - auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; - - int encoding = field_desc.encoding; - bool zigzag = (encoding == spark_rapids_jni::ENC_ZIGZAG); - - // For float/double types, always use fixed kernel (they use wire type 32BIT/64BIT) - // For integer types, use fixed kernel only if encoding is ENC_FIXED - constexpr bool is_floating_point = std::is_same_v || std::is_same_v; - bool use_fixed_kernel = is_floating_point || (encoding == spark_rapids_jni::ENC_FIXED); - - RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; - if (use_fixed_kernel) { - if constexpr (sizeof(T) == 4) { - extract_fixed_kernel - <<>>( - message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); - } else { - 
extract_fixed_kernel - <<>>( - message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); - } - } else if (zigzag) { - extract_varint_kernel - <<>>( - message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); - } else { - extract_varint_kernel - <<>>( - message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); - } - - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - list_offs.release(), - rmm::device_buffer{}, - 0); - auto child_col = std::make_unique( - cudf::data_type{static_cast(field_desc.output_type_id)}, - total_count, - values.release(), - rmm::device_buffer{}, - 0); - - // Only rows where INPUT is null should produce null output - // Rows with valid input but count=0 should produce empty array [] - if (input_null_count > 0) { - // Copy input null mask - only input nulls produce output nulls - auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - input_null_count, - std::move(null_mask), - stream, - mr); - } - - return cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); -} - -/** - * Build a repeated string/bytes column (LIST of STRING or LIST). 
- */ -std::unique_ptr build_repeated_string_column( - cudf::column_view const& binary_input, - uint8_t const* message_data, - cudf::size_type const* list_offsets, - cudf::size_type base_offset, - device_nested_field_descriptor const& field_desc, - rmm::device_uvector const& d_field_counts, - rmm::device_uvector& d_occurrences, - int total_count, - int num_rows, - bool is_bytes, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - auto const input_null_count = binary_input.null_count(); - - if (total_count == 0) { - // All rows have count=0, but we still need to check input nulls - rmm::device_uvector offsets(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - offsets.release(), - rmm::device_buffer{}, - 0); - auto child_col = is_bytes ? make_empty_column_safe( - cudf::data_type{cudf::type_id::LIST}, stream, mr) // LIST - : cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); - - if (input_null_count > 0) { - // Copy input null mask - only input nulls produce output nulls - auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - input_null_count, - std::move(null_mask), - stream, - mr); - } else { - // No input nulls, all rows get empty arrays [] - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - 0, - rmm::device_buffer{}, - stream, - mr); - } - } - - rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); - - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, - &total_count, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - // Extract string lengths from occurrences - rmm::device_uvector 
str_lengths(total_count, stream, mr); - auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; - RepeatedLocationProvider loc_provider{nullptr, 0, d_occurrences.data()}; - extract_lengths_kernel - <<>>(loc_provider, total_count, str_lengths.data()); - - // Compute string offsets via prefix sum - rmm::device_uvector str_offsets(total_count + 1, stream, mr); - build_offsets_from_lengths(str_lengths, str_offsets, stream); - int32_t total_chars = - thrust::reduce(rmm::exec_policy(stream), str_lengths.begin(), str_lengths.end(), 0); - - // Copy string data - rmm::device_uvector chars(total_chars, stream, mr); - rmm::device_uvector d_error(1, stream, mr); - CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); - if (total_chars > 0) { - RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; - copy_varlen_data_kernel<<>>( - message_data, loc_provider, total_count, str_offsets.data(), chars.data(), d_error.data()); - } - - // Build the child column (either STRING or LIST) - std::unique_ptr child_col; - if (is_bytes) { - // Build LIST for bytes (Spark BinaryType maps to LIST) - auto str_offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_count + 1, - str_offsets.release(), - rmm::device_buffer{}, - 0); - auto bytes_child = - std::make_unique(cudf::data_type{cudf::type_id::UINT8}, - total_chars, - rmm::device_buffer(chars.data(), total_chars, stream, mr), - rmm::device_buffer{}, - 0); - child_col = cudf::make_lists_column(total_count, - std::move(str_offsets_col), - std::move(bytes_child), - 0, - rmm::device_buffer{}, - stream, - mr); - } else { - // Build STRING column - auto str_offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_count + 1, - str_offsets.release(), - rmm::device_buffer{}, - 0); - child_col = cudf::make_strings_column( - total_count, std::move(str_offsets_col), chars.release(), 0, rmm::device_buffer{}); 
- } - - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - list_offs.release(), - rmm::device_buffer{}, - 0); - - // Only rows where INPUT is null should produce null output - // Rows with valid input but count=0 should produce empty array [] - if (input_null_count > 0) { - // Copy input null mask - only input nulls produce output nulls - auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - input_null_count, - std::move(null_mask), - stream, - mr); - } - - return cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); -} - -// Forward declaration -- build_nested_struct_column is defined after build_repeated_struct_column -// but the latter's STRUCT-child case needs to call it. -std::unique_ptr build_nested_struct_column( - uint8_t const* message_data, - cudf::size_type const* list_offsets, - cudf::size_type base_offset, - rmm::device_uvector const& d_parent_locs, - std::vector const& child_field_indices, - std::vector const& schema, - int num_fields, - std::vector const& schema_output_types, - std::vector const& default_ints, - std::vector const& default_floats, - std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, - rmm::device_uvector& d_error, - int num_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr, - int depth); - -// Forward declaration -- build_repeated_child_list_column is defined after -// build_nested_struct_column but both build_repeated_struct_column and build_nested_struct_column -// need to call it. 
-std::unique_ptr build_repeated_child_list_column( - uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_parent_rows, - int child_schema_idx, - std::vector const& schema, - int num_fields, - std::vector const& schema_output_types, - std::vector const& default_ints, - std::vector const& default_floats, - std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, - rmm::device_uvector& d_error, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr, - int depth); - -/** - * Build a repeated struct column (LIST of STRUCT). - * This handles repeated message fields like: repeated Item items = 2; - * The output is ArrayType(StructType(...)) - */ -std::unique_ptr build_repeated_struct_column( - cudf::column_view const& binary_input, - uint8_t const* message_data, - cudf::size_type const* list_offsets, - cudf::size_type base_offset, - device_nested_field_descriptor const& field_desc, - rmm::device_uvector const& d_field_counts, - rmm::device_uvector& d_occurrences, - int total_count, - int num_rows, - std::vector const& h_device_schema, - std::vector const& child_field_indices, - std::vector const& schema_output_types, - std::vector const& default_ints, - std::vector const& default_floats, - std::vector const& default_bools, - std::vector> const& default_strings, - std::vector const& schema, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, - rmm::device_uvector& d_error_top, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - auto const input_null_count = binary_input.null_count(); - int num_child_fields = static_cast(child_field_indices.size()); - - if (total_count == 0 || num_child_fields == 0) { - // All rows have count=0 or no child 
fields - return list of empty structs - rmm::device_uvector offsets(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - offsets.release(), - rmm::device_buffer{}, - 0); - - // Build empty struct child column with proper nested structure - int num_schema_fields = static_cast(h_device_schema.size()); - std::vector> empty_struct_children; - for (int child_schema_idx : child_field_indices) { - auto child_type = schema_output_types[child_schema_idx]; - std::unique_ptr child_col; - if (child_type.id() == cudf::type_id::STRUCT) { - child_col = make_empty_struct_column_with_schema( - h_device_schema, schema_output_types, child_schema_idx, num_schema_fields, stream, mr); - } else { - child_col = make_empty_column_safe(child_type, stream, mr); - } - if (h_device_schema[child_schema_idx].is_repeated) { - child_col = make_empty_list_column(std::move(child_col), stream, mr); - } - empty_struct_children.push_back(std::move(child_col)); - } - auto empty_struct = cudf::make_structs_column( - 0, std::move(empty_struct_children), 0, rmm::device_buffer{}, stream, mr); - - if (input_null_count > 0) { - auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(empty_struct), - input_null_count, - std::move(null_mask), - stream, - mr); - } else { - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(empty_struct), - 0, - rmm::device_buffer{}, - stream, - mr); - } - } - - rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); - - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, - &total_count, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - // Build child field descriptors for 
scanning within each message occurrence - std::vector h_child_descs(num_child_fields); - for (int ci = 0; ci < num_child_fields; ci++) { - int child_schema_idx = child_field_indices[ci]; - h_child_descs[ci].field_number = h_device_schema[child_schema_idx].field_number; - h_child_descs[ci].expected_wire_type = h_device_schema[child_schema_idx].wire_type; - } - rmm::device_uvector d_child_descs(num_child_fields, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_descs.data(), - h_child_descs.data(), - num_child_fields * sizeof(field_descriptor), - cudaMemcpyHostToDevice, - stream.value())); - - // For each occurrence, we need to scan for child fields - // Create "virtual" parent locations from the occurrences using GPU kernel - // This replaces the host-side loop with D->H->D copy pattern (critical performance fix!) - rmm::device_uvector d_msg_locs(total_count, stream, mr); - rmm::device_uvector d_msg_row_offsets(total_count, stream, mr); - rmm::device_uvector d_msg_row_offsets_size(total_count, stream, mr); - { - auto const occ_threads = THREADS_PER_BLOCK; - auto const occ_blocks = (total_count + occ_threads - 1) / occ_threads; - compute_msg_locations_from_occurrences_kernel<<>>( - d_occurrences.data(), - list_offsets, - base_offset, - d_msg_locs.data(), - d_msg_row_offsets.data(), - total_count); - } - thrust::transform(rmm::exec_policy(stream), - d_msg_row_offsets.data(), - d_msg_row_offsets.end(), - d_msg_row_offsets_size.data(), - [] __device__(int32_t v) { return static_cast(v); }); - - // Scan for child fields within each message occurrence - rmm::device_uvector d_child_locs(total_count * num_child_fields, stream, mr); - // Reuse top-level error flag so failfast can observe nested repeated-message failures. 
- auto& d_error = d_error_top; - - auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; - - // Use a custom kernel to scan child fields within message occurrences - // This is similar to scan_nested_message_fields_kernel but operates on occurrences - scan_repeated_message_children_kernel<<>>( - message_data, - d_msg_row_offsets.data(), - d_msg_locs.data(), - total_count, - d_child_descs.data(), - num_child_fields, - d_child_locs.data(), - d_error.data()); - - // Note: We no longer need to copy child_locs to host because: - // 1. All scalar extraction kernels access d_child_locs directly on device - // 2. String extraction uses GPU kernels - // 3. Nested struct locations are computed on GPU via compute_nested_struct_locations_kernel - - // Extract child field values - build one column per child field - std::vector> struct_children; - int num_schema_fields = static_cast(h_device_schema.size()); - for (int ci = 0; ci < num_child_fields; ci++) { - int child_schema_idx = child_field_indices[ci]; - auto const dt = schema_output_types[child_schema_idx]; - auto const enc = h_device_schema[child_schema_idx].encoding; - bool has_def = h_device_schema[child_schema_idx].has_default_value; - bool child_is_repeated = h_device_schema[child_schema_idx].is_repeated; - - if (child_is_repeated) { - struct_children.push_back(build_repeated_child_list_column(message_data, - d_msg_row_offsets_size.data(), - 0, - d_msg_locs.data(), - total_count, - child_schema_idx, - schema, - num_schema_fields, - schema_output_types, - default_ints, - default_floats, - default_bools, - default_strings, - enum_valid_values, - enum_names, - d_row_has_invalid_enum, - d_error_top, - stream, - mr, - 1)); - continue; - } - - switch (dt.id()) { - case cudf::type_id::BOOL8: - case cudf::type_id::INT32: - case cudf::type_id::UINT32: - case cudf::type_id::INT64: - case cudf::type_id::UINT64: - case cudf::type_id::FLOAT32: - case cudf::type_id::FLOAT64: { - 
RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields}; - struct_children.push_back( - extract_typed_column(dt, - enc, - message_data, - loc_provider, - total_count, - blocks, - threads, - has_def, - has_def ? default_ints[child_schema_idx] : 0, - has_def ? default_floats[child_schema_idx] : 0.0, - has_def ? default_bools[child_schema_idx] : false, - default_strings[child_schema_idx], - child_schema_idx, - enum_valid_values, - enum_names, - d_row_has_invalid_enum, - d_error, - stream, - mr)); - break; - } - case cudf::type_id::STRING: { - // For strings, we need a two-pass approach: first get lengths, then copy data - struct_children.push_back(build_repeated_msg_child_string_column(message_data, - d_msg_row_offsets, - d_msg_locs, - d_child_locs, - ci, - num_child_fields, - total_count, - d_error, - stream, - mr)); - break; - } - case cudf::type_id::LIST: { - // bytes (BinaryType) child inside repeated message - struct_children.push_back(build_repeated_msg_child_bytes_column(message_data, - d_msg_row_offsets, - d_msg_locs, - d_child_locs, - ci, - num_child_fields, - total_count, - d_error, - stream, - mr)); - break; - } - case cudf::type_id::STRUCT: { - // Nested struct inside repeated message - use recursive build_nested_struct_column - int num_schema_fields = static_cast(h_device_schema.size()); - auto grandchild_indices = - find_child_field_indices(h_device_schema, num_schema_fields, child_schema_idx); - - if (grandchild_indices.empty()) { - struct_children.push_back( - cudf::make_structs_column(total_count, - std::vector>{}, - 0, - rmm::device_buffer{}, - stream, - mr)); - } else { - // Compute virtual parent locations for each occurrence's nested struct child - rmm::device_uvector d_nested_locs(total_count, stream, mr); - rmm::device_uvector d_nested_row_offsets(total_count, stream, mr); - { - // Convert int32_t row offsets to cudf::size_type and compute nested struct 
locations - rmm::device_uvector d_nested_row_offsets_i32(total_count, stream, mr); - compute_nested_struct_locations_kernel<<>>( - d_child_locs.data(), - d_msg_locs.data(), - d_msg_row_offsets.data(), - ci, - num_child_fields, - d_nested_locs.data(), - d_nested_row_offsets_i32.data(), - total_count); - // Add base_offset back so build_nested_struct_column can subtract it - thrust::transform(rmm::exec_policy(stream), - d_nested_row_offsets_i32.data(), - d_nested_row_offsets_i32.end(), - d_nested_row_offsets.data(), - [base_offset] __device__(int32_t v) { - return static_cast(v) + base_offset; - }); - } - - struct_children.push_back(build_nested_struct_column(message_data, - d_nested_row_offsets.data(), - base_offset, - d_nested_locs, - grandchild_indices, - schema, - num_schema_fields, - schema_output_types, - default_ints, - default_floats, - default_bools, - default_strings, - enum_valid_values, - enum_names, - d_row_has_invalid_enum, - d_error_top, - total_count, - stream, - mr, - 0)); - } - break; - } - default: - // Unsupported child type - create null column - struct_children.push_back(make_null_column(dt, total_count, stream, mr)); - break; - } - } - - // Build the struct column from child columns - auto struct_col = cudf::make_structs_column( - total_count, std::move(struct_children), 0, rmm::device_buffer{}, stream, mr); - - // Build the list offsets column - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - list_offs.release(), - rmm::device_buffer{}, - 0); - - // Build the final LIST column - if (input_null_count > 0) { - auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(struct_col), - input_null_count, - std::move(null_mask), - stream, - mr); - } - - return cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(struct_col), 0, rmm::device_buffer{}, stream, mr); -} - -/** - * Recursively build a nested 
/**
 * Recursively build a nested STRUCT column from parent message locations.
 * This supports arbitrarily deep protobuf nesting (bounded by MAX_NESTED_STRUCT_DECODE_DEPTH).
 *
 * NOTE(review): recovered from a mangled patch. Template-argument lists and kernel
 * launch configurations inside <...> were destroyed in transit; every <...> below is
 * reconstructed from context — verify against the original file before merging.
 *
 * @param message_data        Raw concatenated protobuf bytes for all rows.
 * @param list_offsets        Per-row byte offsets into message_data.
 * @param base_offset         Base subtracted from list_offsets when addressing message_data.
 * @param d_parent_locs       Per-row location of the parent message; offset < 0 marks an
 *                            absent parent (the struct row becomes null).
 * @param child_field_indices Schema indices of the direct children of this message.
 * @param schema              Flattened schema descriptors for all fields.
 * @param num_fields          Total number of schema entries.
 * @param schema_output_types cudf output type per schema entry.
 * @param default_ints / default_floats / default_bools / default_strings
 *                            Per-field proto2 default values.
 * @param enum_valid_values / enum_names Per-field enum number -> name tables.
 * @param d_row_has_invalid_enum Per-row flag set when an enum value is outside the valid set.
 * @param d_error             Device error flag shared by all kernels.
 * @param num_rows            Number of parent rows to decode.
 * @param depth               Current recursion depth.
 * @return STRUCT column with one child per entry in child_field_indices.
 */
std::unique_ptr<cudf::column> build_nested_struct_column(
  uint8_t const* message_data,
  cudf::size_type const* list_offsets,
  cudf::size_type base_offset,
  rmm::device_uvector<field_location> const& d_parent_locs,  // NOTE(review): element type inferred
  std::vector<int> const& child_field_indices,
  std::vector<device_nested_field_descriptor> const& schema,  // NOTE(review): type inferred
  int num_fields,
  std::vector<cudf::data_type> const& schema_output_types,    // NOTE(review): type inferred
  std::vector<int64_t> const& default_ints,                   // NOTE(review): type inferred
  std::vector<double> const& default_floats,                  // NOTE(review): type inferred
  std::vector<bool> const& default_bools,                     // NOTE(review): type inferred
  std::vector<std::vector<uint8_t>> const& default_strings,
  std::vector<std::vector<int32_t>> const& enum_valid_values,
  std::vector<std::vector<std::vector<uint8_t>>> const& enum_names,
  rmm::device_uvector<bool>& d_row_has_invalid_enum,
  rmm::device_uvector<int>& d_error,
  int num_rows,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr,
  int depth)
{
  CUDF_EXPECTS(depth < MAX_NESTED_STRUCT_DECODE_DEPTH,
               "Nested protobuf struct depth exceeds supported decode recursion limit");

  // Zero-row fast path: build an empty (0-row) child per schema entry so the output
  // STRUCT has the right shape even with no data.
  if (num_rows == 0) {
    std::vector<std::unique_ptr<cudf::column>> empty_children;
    for (int child_schema_idx : child_field_indices) {
      auto child_type = schema_output_types[child_schema_idx];
      std::unique_ptr<cudf::column> child_col;
      if (child_type.id() == cudf::type_id::STRUCT) {
        child_col = make_empty_struct_column_with_schema(
          schema, schema_output_types, child_schema_idx, num_fields, stream, mr);
      } else {
        child_col = make_empty_column_safe(child_type, stream, mr);
      }
      // Repeated fields surface as LIST<element>.
      if (schema[child_schema_idx].is_repeated) {
        child_col = make_empty_list_column(std::move(child_col), stream, mr);
      }
      empty_children.push_back(std::move(child_col));
    }
    return cudf::make_structs_column(
      0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr);
  }

  auto const threads   = THREADS_PER_BLOCK;
  auto const blocks    = static_cast<int>((num_rows + threads - 1) / threads);
  int num_child_fields = static_cast<int>(child_field_indices.size());

  // Stage the (field number, wire type) pairs the scan kernel must look for.
  std::vector<field_descriptor> h_child_field_descs(num_child_fields);
  for (int i = 0; i < num_child_fields; i++) {
    int child_idx                             = child_field_indices[i];
    h_child_field_descs[i].field_number       = schema[child_idx].field_number;
    h_child_field_descs[i].expected_wire_type = schema[child_idx].wire_type;
  }

  rmm::device_uvector<field_descriptor> d_child_field_descs(num_child_fields, stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_field_descs.data(),
                                h_child_field_descs.data(),
                                num_child_fields * sizeof(field_descriptor),
                                cudaMemcpyHostToDevice,
                                stream.value()));

  // One location slot per (row, child field), row-major.
  rmm::device_uvector<field_location> d_child_locations(
    static_cast<size_t>(num_rows) * num_child_fields, stream, mr);
  scan_nested_message_fields_kernel<<<blocks, threads, 0, stream.value()>>>(
    message_data,
    list_offsets,
    base_offset,
    d_parent_locs.data(),
    num_rows,
    d_child_field_descs.data(),
    num_child_fields,
    d_child_locations.data(),
    d_error.data());

  std::vector<std::unique_ptr<cudf::column>> struct_children;
  for (int ci = 0; ci < num_child_fields; ci++) {
    int child_schema_idx = child_field_indices[ci];
    auto const dt        = schema_output_types[child_schema_idx];
    auto const enc       = schema[child_schema_idx].encoding;
    bool has_def         = schema[child_schema_idx].has_default_value;
    bool is_repeated     = schema[child_schema_idx].is_repeated;

    // Repeated children are handled by the shared LIST builder.
    if (is_repeated) {
      struct_children.push_back(build_repeated_child_list_column(message_data,
                                                                 list_offsets,
                                                                 base_offset,
                                                                 d_parent_locs.data(),
                                                                 num_rows,
                                                                 child_schema_idx,
                                                                 schema,
                                                                 num_fields,
                                                                 schema_output_types,
                                                                 default_ints,
                                                                 default_floats,
                                                                 default_bools,
                                                                 default_strings,
                                                                 enum_valid_values,
                                                                 enum_names,
                                                                 d_row_has_invalid_enum,
                                                                 d_error,
                                                                 stream,
                                                                 mr,
                                                                 depth));
      continue;
    }

    switch (dt.id()) {
      case cudf::type_id::BOOL8:
      case cudf::type_id::INT32:
      case cudf::type_id::UINT32:
      case cudf::type_id::INT64:
      case cudf::type_id::UINT64:
      case cudf::type_id::FLOAT32:
      case cudf::type_id::FLOAT64: {
        // Fixed-width scalar child: locations are resolved through the parent message.
        NestedLocationProvider loc_provider{list_offsets,
                                            base_offset,
                                            d_parent_locs.data(),
                                            d_child_locations.data(),
                                            ci,
                                            num_child_fields};
        struct_children.push_back(
          extract_typed_column(dt,
                               enc,
                               message_data,
                               loc_provider,
                               num_rows,
                               blocks,
                               threads,
                               has_def,
                               has_def ? default_ints[child_schema_idx] : 0,
                               has_def ? default_floats[child_schema_idx] : 0.0,
                               has_def ? default_bools[child_schema_idx] : false,
                               default_strings[child_schema_idx],
                               child_schema_idx,
                               enum_valid_values,
                               enum_names,
                               d_row_has_invalid_enum,
                               d_error,
                               stream,
                               mr));
        break;
      }
      case cudf::type_id::STRING: {
        if (enc == spark_rapids_jni::ENC_ENUM_STRING) {
          // Enum rendered as its name string: first decode the varint enum numbers...
          rmm::device_uvector<int64_t> out(num_rows, stream, mr);  // NOTE(review): type inferred
          rmm::device_uvector<bool> valid((num_rows > 0 ? num_rows : 1), stream, mr);
          int64_t def_int = has_def ? default_ints[child_schema_idx] : 0;
          NestedLocationProvider loc_provider{list_offsets,
                                              base_offset,
                                              d_parent_locs.data(),
                                              d_child_locations.data(),
                                              ci,
                                              num_child_fields};
          extract_varint_kernel
            <<<blocks, threads, 0, stream.value()>>>(message_data,
                                                     loc_provider,
                                                     num_rows,
                                                     out.data(),
                                                     valid.data(),
                                                     d_error.data(),
                                                     has_def,
                                                     def_int);

          // ...then map numbers to names, provided the lookup tables are consistent.
          if (child_schema_idx < static_cast<int>(enum_valid_values.size()) &&
              child_schema_idx < static_cast<int>(enum_names.size())) {
            auto const& valid_enums     = enum_valid_values[child_schema_idx];
            auto const& enum_name_bytes = enum_names[child_schema_idx];
            if (!valid_enums.empty() && valid_enums.size() == enum_name_bytes.size()) {
              struct_children.push_back(build_enum_string_column(out,
                                                                 valid,
                                                                 valid_enums,
                                                                 enum_name_bytes,
                                                                 d_row_has_invalid_enum,
                                                                 num_rows,
                                                                 stream,
                                                                 mr));
            } else {
              // Malformed enum table: raise the shared error flag and emit nulls.
              CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value()));
              struct_children.push_back(make_null_column(dt, num_rows, stream, mr));
            }
          } else {
            CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value()));
            struct_children.push_back(make_null_column(dt, num_rows, stream, mr));
          }
        } else {
          // Plain string child. Length pass uses a provider without data offsets;
          // copy pass uses the full provider. Rows are valid when both the parent
          // and the child field are present, or a default value exists.
          bool has_def_str    = has_def;
          auto const& def_str = default_strings[child_schema_idx];
          NestedLocationProvider len_provider{
            nullptr, 0, d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields};
          NestedLocationProvider copy_provider{list_offsets,
                                               base_offset,
                                               d_parent_locs.data(),
                                               d_child_locations.data(),
                                               ci,
                                               num_child_fields};
          auto valid_fn = [plocs = d_parent_locs.data(),
                           flocs = d_child_locations.data(),
                           ci,
                           num_child_fields,
                           has_def_str] __device__(cudf::size_type row) {
            return (plocs[row].offset >= 0 && flocs[row * num_child_fields + ci].offset >= 0) ||
                   has_def_str;
          };
          struct_children.push_back(extract_and_build_string_or_bytes_column(false,
                                                                             message_data,
                                                                             num_rows,
                                                                             len_provider,
                                                                             copy_provider,
                                                                             valid_fn,
                                                                             has_def_str,
                                                                             def_str,
                                                                             d_error,
                                                                             stream,
                                                                             mr));
        }
        break;
      }
      case cudf::type_id::LIST: {
        // bytes (BinaryType) represented as LIST; same two-provider scheme as strings.
        bool has_def_bytes    = has_def;
        auto const& def_bytes = default_strings[child_schema_idx];
        NestedLocationProvider len_provider{
          nullptr, 0, d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields};
        NestedLocationProvider copy_provider{list_offsets,
                                             base_offset,
                                             d_parent_locs.data(),
                                             d_child_locations.data(),
                                             ci,
                                             num_child_fields};
        auto valid_fn = [plocs = d_parent_locs.data(),
                         flocs = d_child_locations.data(),
                         ci,
                         num_child_fields,
                         has_def_bytes] __device__(cudf::size_type row) {
          return (plocs[row].offset >= 0 && flocs[row * num_child_fields + ci].offset >= 0) ||
                 has_def_bytes;
        };
        struct_children.push_back(extract_and_build_string_or_bytes_column(true,
                                                                           message_data,
                                                                           num_rows,
                                                                           len_provider,
                                                                           copy_provider,
                                                                           valid_fn,
                                                                           has_def_bytes,
                                                                           def_bytes,
                                                                           d_error,
                                                                           stream,
                                                                           mr));
        break;
      }
      case cudf::type_id::STRUCT: {
        // Nested message child: compute grandchild parent locations, then recurse.
        auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx);
        if (gc_indices.empty()) {
          struct_children.push_back(make_null_column(dt, num_rows, stream, mr));
          break;
        }
        rmm::device_uvector<field_location> d_gc_parent(num_rows, stream, mr);
        compute_grandchild_parent_locations_kernel<<<blocks, threads, 0, stream.value()>>>(
          d_parent_locs.data(),
          d_child_locations.data(),
          ci,
          num_child_fields,
          d_gc_parent.data(),
          num_rows);
        struct_children.push_back(build_nested_struct_column(message_data,
                                                             list_offsets,
                                                             base_offset,
                                                             d_gc_parent,
                                                             gc_indices,
                                                             schema,
                                                             num_fields,
                                                             schema_output_types,
                                                             default_ints,
                                                             default_floats,
                                                             default_bools,
                                                             default_strings,
                                                             enum_valid_values,
                                                             enum_names,
                                                             d_row_has_invalid_enum,
                                                             d_error,
                                                             num_rows,
                                                             stream,
                                                             mr,
                                                             depth + 1));
        break;
      }
      default: struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); break;
    }
  }

  // A nested struct row is valid iff its parent message field was present.
  rmm::device_uvector<bool> struct_valid((num_rows > 0 ? num_rows : 1), stream, mr);
  thrust::transform(
    rmm::exec_policy(stream),
    thrust::make_counting_iterator(0),
    thrust::make_counting_iterator(num_rows),
    struct_valid.data(),
    [plocs = d_parent_locs.data()] __device__(auto row) { return plocs[row].offset >= 0; });
  auto [struct_mask, struct_null_count] = make_null_mask_from_valid(struct_valid, stream, mr);
  return cudf::make_structs_column(
    num_rows, std::move(struct_children), struct_null_count, std::move(struct_mask), stream, mr);
}
/**
 * Build a LIST column for a repeated child field inside a parent message.
 * Shared between build_nested_struct_column and build_repeated_struct_column.
 *
 * NOTE(review): recovered from a mangled patch. Template-argument lists and kernel
 * launch configurations inside <...> were destroyed in transit; every <...> below is
 * reconstructed from context — verify against the original file before merging.
 *
 * @param parent_locs      Per-parent-row location of the enclosing message (offset < 0 = absent).
 * @param child_schema_idx Schema index of the repeated field to extract.
 * @param depth            Current recursion depth (forwarded unchanged; STRUCT elements
 *                         recurse with depth + 1).
 * @return LIST column with one list per parent row; element type follows the schema.
 */
std::unique_ptr<cudf::column> build_repeated_child_list_column(
  uint8_t const* message_data,
  cudf::size_type const* row_offsets,
  cudf::size_type base_offset,
  field_location const* parent_locs,
  int num_parent_rows,
  int child_schema_idx,
  std::vector<device_nested_field_descriptor> const& schema,  // NOTE(review): type inferred
  int num_fields,
  std::vector<cudf::data_type> const& schema_output_types,    // NOTE(review): type inferred
  std::vector<int64_t> const& default_ints,                   // NOTE(review): type inferred
  std::vector<double> const& default_floats,                  // NOTE(review): type inferred
  std::vector<bool> const& default_bools,                     // NOTE(review): type inferred
  std::vector<std::vector<uint8_t>> const& default_strings,
  std::vector<std::vector<int32_t>> const& enum_valid_values,
  std::vector<std::vector<std::vector<uint8_t>>> const& enum_names,
  rmm::device_uvector<bool>& d_row_has_invalid_enum,
  rmm::device_uvector<int>& d_error,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr,
  int depth)
{
  auto const threads = THREADS_PER_BLOCK;
  auto const blocks  = static_cast<int>((num_parent_rows + threads - 1) / threads);

  auto elem_type_id = schema[child_schema_idx].output_type;
  rmm::device_uvector<repeated_field_info> d_rep_info(num_parent_rows, stream, mr);

  // The counting/scanning kernels take arrays of repeated-field indices; here there
  // is exactly one repeated field, at index 0.
  std::vector<int> rep_indices = {0};
  rmm::device_uvector<int> d_rep_indices(1, stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(
    d_rep_indices.data(), rep_indices.data(), sizeof(int), cudaMemcpyHostToDevice, stream.value()));

  // Single-entry schema describing just this repeated field for the device kernels.
  device_nested_field_descriptor rep_desc;
  rep_desc.field_number      = schema[child_schema_idx].field_number;
  rep_desc.wire_type         = schema[child_schema_idx].wire_type;
  rep_desc.output_type_id    = static_cast<int>(schema[child_schema_idx].output_type);
  rep_desc.is_repeated       = true;
  rep_desc.parent_idx        = -1;
  rep_desc.depth             = 0;
  rep_desc.encoding          = 0;
  rep_desc.is_required       = false;
  rep_desc.has_default_value = false;

  std::vector<device_nested_field_descriptor> h_rep_schema = {rep_desc};
  rmm::device_uvector<device_nested_field_descriptor> d_rep_schema(1, stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_schema.data(),
                                h_rep_schema.data(),
                                sizeof(device_nested_field_descriptor),
                                cudaMemcpyHostToDevice,
                                stream.value()));

  // Pass 1: count occurrences of the repeated field per parent row.
  count_repeated_in_nested_kernel<<<blocks, threads, 0, stream.value()>>>(message_data,
                                                                          row_offsets,
                                                                          base_offset,
                                                                          parent_locs,
                                                                          num_parent_rows,
                                                                          d_rep_schema.data(),
                                                                          1,
                                                                          d_rep_info.data(),
                                                                          1,
                                                                          d_rep_indices.data(),
                                                                          d_error.data());

  rmm::device_uvector<int32_t> d_rep_counts(num_parent_rows, stream, mr);
  thrust::transform(rmm::exec_policy(stream),
                    d_rep_info.data(),
                    d_rep_info.end(),
                    d_rep_counts.data(),
                    [] __device__(repeated_field_info const& info) { return info.count; });
  int total_rep_count =
    thrust::reduce(rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), 0);

  // Fast path: no occurrences anywhere — every parent row gets an empty list.
  if (total_rep_count == 0) {
    rmm::device_uvector<int32_t> list_offsets_vec(num_parent_rows + 1, stream, mr);
    thrust::fill(rmm::exec_policy(stream), list_offsets_vec.data(), list_offsets_vec.end(), 0);
    auto list_offsets_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
                                                           num_parent_rows + 1,
                                                           list_offsets_vec.release(),
                                                           rmm::device_buffer{},
                                                           0);
    std::unique_ptr<cudf::column> child_col;
    if (elem_type_id == cudf::type_id::STRUCT) {
      child_col = make_empty_struct_column_with_schema(
        schema, schema_output_types, child_schema_idx, num_fields, stream, mr);
    } else {
      child_col = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr);
    }
    return cudf::make_lists_column(num_parent_rows,
                                   std::move(list_offsets_col),
                                   std::move(child_col),
                                   0,
                                   rmm::device_buffer{},
                                   stream,
                                   mr);
  }

  // List offsets: exclusive scan of per-row counts, with the grand total appended.
  rmm::device_uvector<int32_t> list_offs(num_parent_rows + 1, stream, mr);
  thrust::exclusive_scan(
    rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), list_offs.begin(), 0);
  CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_parent_rows,
                                &total_rep_count,
                                sizeof(int32_t),
                                cudaMemcpyHostToDevice,
                                stream.value()));

  // Pass 2: record the byte location of every occurrence.
  rmm::device_uvector<repeated_occurrence> d_rep_occs(total_rep_count, stream, mr);
  scan_repeated_in_nested_kernel<<<blocks, threads, 0, stream.value()>>>(message_data,
                                                                         row_offsets,
                                                                         base_offset,
                                                                         parent_locs,
                                                                         num_parent_rows,
                                                                         d_rep_schema.data(),
                                                                         1,
                                                                         list_offs.data(),
                                                                         1,
                                                                         d_rep_indices.data(),
                                                                         d_rep_occs.data(),
                                                                         d_error.data());

  // Extract the flattened element values according to the schema element type.
  std::unique_ptr<cudf::column> child_values;
  if (elem_type_id == cudf::type_id::INT32) {
    rmm::device_uvector<int32_t> values(total_rep_count, stream, mr);
    NestedRepeatedLocationProvider loc_provider{
      row_offsets, base_offset, parent_locs, d_rep_occs.data()};
    extract_varint_kernel
      <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK,
         THREADS_PER_BLOCK,
         0,
         stream.value()>>>(
        message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data());
    child_values = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
                                                  total_rep_count,
                                                  values.release(),
                                                  rmm::device_buffer{},
                                                  0);
  } else if (elem_type_id == cudf::type_id::INT64) {
    rmm::device_uvector<int64_t> values(total_rep_count, stream, mr);
    NestedRepeatedLocationProvider loc_provider{
      row_offsets, base_offset, parent_locs, d_rep_occs.data()};
    extract_varint_kernel
      <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK,
         THREADS_PER_BLOCK,
         0,
         stream.value()>>>(
        message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data());
    child_values = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT64},
                                                  total_rep_count,
                                                  values.release(),
                                                  rmm::device_buffer{},
                                                  0);
  } else if (elem_type_id == cudf::type_id::BOOL8) {
    rmm::device_uvector<bool> values(total_rep_count, stream, mr);
    NestedRepeatedLocationProvider loc_provider{
      row_offsets, base_offset, parent_locs, d_rep_occs.data()};
    extract_varint_kernel
      <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK,
         THREADS_PER_BLOCK,
         0,
         stream.value()>>>(
        message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data());
    child_values = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::BOOL8},
                                                  total_rep_count,
                                                  values.release(),
                                                  rmm::device_buffer{},
                                                  0);
  } else if (elem_type_id == cudf::type_id::FLOAT32) {
    rmm::device_uvector<float> values(total_rep_count, stream, mr);
    NestedRepeatedLocationProvider loc_provider{
      row_offsets, base_offset, parent_locs, d_rep_occs.data()};
    extract_fixed_kernel
      <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK,
         THREADS_PER_BLOCK,
         0,
         stream.value()>>>(
        message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data());
    child_values = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::FLOAT32},
                                                  total_rep_count,
                                                  values.release(),
                                                  rmm::device_buffer{},
                                                  0);
  } else if (elem_type_id == cudf::type_id::FLOAT64) {
    rmm::device_uvector<double> values(total_rep_count, stream, mr);
    NestedRepeatedLocationProvider loc_provider{
      row_offsets, base_offset, parent_locs, d_rep_occs.data()};
    extract_fixed_kernel
      <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK,
         THREADS_PER_BLOCK,
         0,
         stream.value()>>>(
        message_data, loc_provider, total_rep_count, values.data(), nullptr, d_error.data());
    child_values = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::FLOAT64},
                                                  total_rep_count,
                                                  values.release(),
                                                  rmm::device_buffer{},
                                                  0);
  } else if (elem_type_id == cudf::type_id::STRING) {
    // Strings: lengths -> offsets -> character copy.
    rmm::device_uvector<int32_t> d_str_lengths(total_rep_count, stream, mr);
    thrust::transform(rmm::exec_policy(stream),
                      d_rep_occs.data(),
                      d_rep_occs.end(),
                      d_str_lengths.data(),
                      [] __device__(repeated_occurrence const& occ) { return occ.length; });

    int32_t total_chars =
      thrust::reduce(rmm::exec_policy(stream), d_str_lengths.data(), d_str_lengths.end(), 0);
    rmm::device_uvector<int32_t> str_offs(total_rep_count + 1, stream, mr);
    thrust::exclusive_scan(
      rmm::exec_policy(stream), d_str_lengths.data(), d_str_lengths.end(), str_offs.data(), 0);
    CUDF_CUDA_TRY(cudaMemcpyAsync(str_offs.data() + total_rep_count,
                                  &total_chars,
                                  sizeof(int32_t),
                                  cudaMemcpyHostToDevice,
                                  stream.value()));

    rmm::device_uvector<char> chars(total_chars, stream, mr);
    if (total_chars > 0) {
      NestedRepeatedLocationProvider loc_provider{
        row_offsets, base_offset, parent_locs, d_rep_occs.data()};
      copy_varlen_data_kernel
        <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK,
           THREADS_PER_BLOCK,
           0,
           stream.value()>>>(message_data,
                             loc_provider,
                             total_rep_count,
                             str_offs.data(),
                             chars.data(),
                             d_error.data());
    }

    auto str_offs_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
                                                       total_rep_count + 1,
                                                       str_offs.release(),
                                                       rmm::device_buffer{},
                                                       0);
    child_values = cudf::make_strings_column(
      total_rep_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{});
  } else if (elem_type_id == cudf::type_id::LIST) {
    // bytes (BinaryType) elements: LIST<UINT8> per occurrence.
    rmm::device_uvector<int32_t> d_len(total_rep_count, stream, mr);
    thrust::transform(rmm::exec_policy(stream),
                      d_rep_occs.data(),
                      d_rep_occs.end(),
                      d_len.data(),
                      [] __device__(repeated_occurrence const& occ) { return occ.length; });

    int32_t total_bytes = thrust::reduce(rmm::exec_policy(stream), d_len.data(), d_len.end(), 0);
    rmm::device_uvector<int32_t> byte_offs(total_rep_count + 1, stream, mr);
    thrust::exclusive_scan(
      rmm::exec_policy(stream), d_len.data(), d_len.end(), byte_offs.data(), 0);
    CUDF_CUDA_TRY(cudaMemcpyAsync(byte_offs.data() + total_rep_count,
                                  &total_bytes,
                                  sizeof(int32_t),
                                  cudaMemcpyHostToDevice,
                                  stream.value()));

    rmm::device_uvector<uint8_t> bytes(total_bytes, stream, mr);
    if (total_bytes > 0) {
      NestedRepeatedLocationProvider loc_provider{
        row_offsets, base_offset, parent_locs, d_rep_occs.data()};
      copy_varlen_data_kernel
        <<<(total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK,
           THREADS_PER_BLOCK,
           0,
           stream.value()>>>(message_data,
                             loc_provider,
                             total_rep_count,
                             byte_offs.data(),
                             bytes.data(),
                             d_error.data());
    }

    auto offs_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
                                                   total_rep_count + 1,
                                                   byte_offs.release(),
                                                   rmm::device_buffer{},
                                                   0);
    auto bytes_child =
      std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::UINT8},
                                     total_bytes,
                                     rmm::device_buffer(bytes.data(), total_bytes, stream, mr),
                                     rmm::device_buffer{},
                                     0);
    child_values = cudf::make_lists_column(total_rep_count,
                                           std::move(offs_col),
                                           std::move(bytes_child),
                                           0,
                                           rmm::device_buffer{},
                                           stream,
                                           mr);
  } else if (elem_type_id == cudf::type_id::STRUCT) {
    // Repeated message elements: synthesize one "virtual parent" per occurrence and recurse.
    auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx);
    if (gc_indices.empty()) {
      child_values = cudf::make_structs_column(total_rep_count,
                                               std::vector<std::unique_ptr<cudf::column>>{},
                                               0,
                                               rmm::device_buffer{},
                                               stream,
                                               mr);
    } else {
      rmm::device_uvector<cudf::size_type> d_virtual_row_offsets(total_rep_count, stream, mr);
      rmm::device_uvector<field_location> d_virtual_parent_locs(total_rep_count, stream, mr);
      auto const rep_blk = (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
      compute_virtual_parents_for_nested_repeated_kernel<<<rep_blk,
                                                           THREADS_PER_BLOCK,
                                                           0,
                                                           stream.value()>>>(
        d_rep_occs.data(),
        row_offsets,
        parent_locs,
        d_virtual_row_offsets.data(),
        d_virtual_parent_locs.data(),
        total_rep_count);

      child_values = build_nested_struct_column(message_data,
                                                d_virtual_row_offsets.data(),
                                                base_offset,
                                                d_virtual_parent_locs,
                                                gc_indices,
                                                schema,
                                                num_fields,
                                                schema_output_types,
                                                default_ints,
                                                default_floats,
                                                default_bools,
                                                default_strings,
                                                enum_valid_values,
                                                enum_names,
                                                d_row_has_invalid_enum,
                                                d_error,
                                                total_rep_count,
                                                stream,
                                                mr,
                                                depth + 1);
    }
  } else {
    // NOTE(review): fallback produces a 0-row child while list offsets sum to
    // total_rep_count > 0 — likely only reachable for unsupported types; confirm.
    child_values = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr);
  }

  auto list_offs_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
                                                      num_parent_rows + 1,
                                                      list_offs.release(),
                                                      rmm::device_buffer{},
                                                      0);
  return cudf::make_lists_column(num_parent_rows,
                                 std::move(list_offs_col),
                                 std::move(child_values),
                                 0,
                                 rmm::device_buffer{},
                                 stream,
                                 mr);
}

}  // anonymous namespace
RepeatedLocationProvider rep_loc{list_offsets, base_offset, d_occurrences.data()}; + extract_varint_kernel <<>>(message_data, - list_offsets, - base_offset, - d_occurrences.data(), + rep_loc, total_count, enum_ints.data(), + nullptr, d_error.data()); // 2. Build device-side enum lookup tables @@ -4995,10 +639,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& total_count); // 5. Build string offsets - rmm::device_uvector str_offsets(total_count + 1, stream, mr); - build_offsets_from_lengths(elem_lengths, str_offsets, stream); - int32_t total_chars = thrust::reduce( - rmm::exec_policy(stream), elem_lengths.begin(), elem_lengths.end(), 0); + auto [str_offs_col, total_chars] = + cudf::strings::detail::make_offsets_child_column( + elem_lengths.begin(), elem_lengths.end(), stream, mr); // 6. Copy string chars rmm::device_uvector chars(total_chars, stream, mr); @@ -5010,18 +653,12 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_name_offsets.data(), d_name_chars.data(), static_cast(valid_enums.size()), - str_offsets.data(), + str_offs_col->view().data(), chars.data(), total_count); } // 7. Assemble LIST column - auto str_offs_col = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - total_count + 1, - str_offsets.release(), - rmm::device_buffer{}, - 0); auto child_col = cudf::make_strings_column( total_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); @@ -5286,6 +923,31 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& struct_null_count = null_count; } + // cuDF struct child views do not inherit parent nulls. Push PERMISSIVE invalid-enum nulls + // down into every top-level child so extracted fields respect "null struct => null field". 
+ if (has_enum_fields && struct_null_count > 0) { + auto const* struct_mask_ptr = static_cast(struct_mask.data()); + for (auto& child : top_level_children) { + auto child_view = child->mutable_view(); + if (child_view.nullable()) { + auto const child_mask_words = + cudf::num_bitmask_words(static_cast(child_view.size() + child_view.offset())); + std::array masks{child_view.null_mask(), struct_mask_ptr}; + std::array begin_bits{child_view.offset(), 0}; + auto const valid_count = cudf::detail::inplace_bitmask_and( + cudf::device_span(child_view.null_mask(), child_mask_words), + cudf::host_span(masks.data(), masks.size()), + cudf::host_span(begin_bits.data(), begin_bits.size()), + child_view.size(), + stream); + child->set_null_count(child_view.size() - valid_count); + } else { + auto child_mask = cudf::detail::copy_bitmask(struct_mask_ptr, 0, num_rows, stream, mr); + child->set_null_mask(std::move(child_mask), struct_null_count); + } + } + } + return cudf::make_structs_column( num_rows, std::move(top_level_children), struct_null_count, std::move(struct_mask), stream, mr); } diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu new file mode 100644 index 0000000000..a06b3b7b1d --- /dev/null +++ b/src/main/cpp/src/protobuf_builders.cu @@ -0,0 +1,1407 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * Helper to build string or bytes column for repeated message child fields.
 * When as_bytes=false, builds a STRING column. When as_bytes=true, builds LIST<UINT8>.
 * Uses GPU kernels for parallel extraction (critical performance fix!).
 *
 * NOTE(review): recovered from a mangled patch. Template-argument lists and kernel
 * launch configurations inside <...> were destroyed in transit; every <...> below is
 * reconstructed from context — verify against the original file before merging.
 *
 * @param d_child_locs Row-major (occurrence x child field) locations; offset < 0 = absent,
 *                     which yields a null output row and a zero length.
 * @param child_idx    Column index of the target child within d_child_locs.
 * @param total_count  Number of flattened message occurrences to process.
 * @param d_error      Device error flag raised by the copy kernel on malformed data.
 */
inline std::unique_ptr<cudf::column> build_repeated_msg_child_varlen_column(
  uint8_t const* message_data,
  rmm::device_uvector<cudf::size_type> const& d_msg_row_offsets,  // NOTE(review): type inferred
  rmm::device_uvector<field_location> const& d_msg_locs,          // NOTE(review): type inferred
  rmm::device_uvector<field_location> const& d_child_locs,        // NOTE(review): type inferred
  int child_idx,
  int num_child_fields,
  int total_count,
  rmm::device_uvector<int>& d_error,
  bool as_bytes,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr)
{
  if (total_count == 0) {
    if (as_bytes) {
      return make_empty_column_safe(cudf::data_type{cudf::type_id::LIST}, stream, mr);
    }
    return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING});
  }

  auto const threads = THREADS_PER_BLOCK;
  auto const blocks  = (total_count + threads - 1) / threads;

  // Per-occurrence payload length; absent fields contribute zero bytes.
  rmm::device_uvector<int32_t> d_lengths(total_count, stream, mr);
  thrust::transform(
    rmm::exec_policy(stream),
    thrust::make_counting_iterator(0),
    thrust::make_counting_iterator(total_count),
    d_lengths.data(),
    [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) {
      auto const& loc = child_locs[idx * ncf + ci];
      return loc.offset >= 0 ? loc.length : 0;
    });

  auto [offsets_col, total_data] = cudf::strings::detail::make_offsets_child_column(
    d_lengths.begin(), d_lengths.end(), stream, mr);

  rmm::device_uvector<uint8_t> d_data(total_data, stream, mr);
  rmm::device_uvector<bool> d_valid((total_count > 0 ? total_count : 1), stream, mr);

  // Row validity: the child field was present in its occurrence.
  thrust::transform(
    rmm::exec_policy(stream),
    thrust::make_counting_iterator(0),
    thrust::make_counting_iterator(total_count),
    d_valid.data(),
    [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) {
      return child_locs[idx * ncf + ci].offset >= 0;
    });

  if (total_data > 0) {
    RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(),
                                                  0,
                                                  d_msg_locs.data(),
                                                  d_child_locs.data(),
                                                  child_idx,
                                                  num_child_fields};
    copy_varlen_data_kernel
      <<<blocks, threads, 0, stream.value()>>>(message_data,
                                               loc_provider,
                                               total_count,
                                               offsets_col->view().data<cudf::size_type>(),
                                               d_data.data(),
                                               d_error.data());
  }

  auto [mask, null_count] = make_null_mask_from_valid(d_valid, stream, mr);

  if (as_bytes) {
    auto bytes_child =
      std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::UINT8},
                                     total_data,
                                     rmm::device_buffer(d_data.data(), total_data, stream, mr),
                                     rmm::device_buffer{},
                                     0);
    return cudf::make_lists_column(total_count,
                                   std::move(offsets_col),
                                   std::move(bytes_child),
                                   null_count,
                                   std::move(mask),
                                   stream,
                                   mr);
  }

  return cudf::make_strings_column(
    total_count, std::move(offsets_col), d_data.release(), null_count, std::move(mask));
}
/**
 * Create an all-null column of the specified type.
 *
 * NOTE(review): recovered from a mangled patch; template arguments inside <...> were
 * lost and have been reconstructed — verify against the original file.
 *
 * @param dtype    Target cudf output type (fixed-width, STRING, LIST, or STRUCT).
 * @param num_rows Number of rows; 0 returns cudf::make_empty_column(dtype).
 * @throws cudf::logic_error (via CUDF_FAIL) for unsupported type ids.
 */
std::unique_ptr<cudf::column> make_null_column(cudf::data_type dtype,
                                               cudf::size_type num_rows,
                                               rmm::cuda_stream_view stream,
                                               rmm::device_async_resource_ref mr)
{
  if (num_rows == 0) { return cudf::make_empty_column(dtype); }

  switch (dtype.id()) {
    case cudf::type_id::BOOL8:
    case cudf::type_id::INT8:
    case cudf::type_id::UINT8:
    case cudf::type_id::INT16:
    case cudf::type_id::UINT16:
    case cudf::type_id::INT32:
    case cudf::type_id::UINT32:
    case cudf::type_id::INT64:
    case cudf::type_id::UINT64:
    case cudf::type_id::FLOAT32:
    case cudf::type_id::FLOAT64:
      return cudf::make_fixed_width_column(dtype, num_rows, cudf::mask_state::ALL_NULL, stream, mr);
    case cudf::type_id::STRING: {
      // A {nullptr, 0} index pair is the libcudf convention for a null string row.
      rmm::device_uvector<cudf::strings::detail::string_index_pair> pairs(num_rows, stream, mr);
      thrust::fill(rmm::exec_policy(stream),
                   pairs.data(),
                   pairs.end(),
                   cudf::strings::detail::string_index_pair{nullptr, 0});
      return cudf::strings::detail::make_strings_column(pairs.data(), pairs.end(), stream, mr);
    }
    case cudf::type_id::LIST:
      // LIST here always carries UINT8 elements (Spark BinaryType mapping).
      return cudf::lists::detail::make_all_nulls_lists_column(
        num_rows, cudf::data_type{cudf::type_id::UINT8}, stream, mr);
    case cudf::type_id::STRUCT: {
      // Childless struct with an all-null mask.
      std::vector<std::unique_ptr<cudf::column>> empty_children;
      auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr);
      return cudf::make_structs_column(
        num_rows, std::move(empty_children), num_rows, std::move(null_mask), stream, mr);
    }
    default: CUDF_FAIL("Unsupported type for null column creation");
  }
}
/**
 * Create an empty column (0 rows) of the specified type.
 * This handles nested types (LIST, STRUCT) that cudf::make_empty_column doesn't support.
 *
 * NOTE(review): recovered from a mangled patch; template arguments inside <...> were
 * lost and have been reconstructed — verify against the original file.
 */
std::unique_ptr<cudf::column> make_empty_column_safe(cudf::data_type dtype,
                                                     rmm::cuda_stream_view stream,
                                                     rmm::device_async_resource_ref mr)
{
  switch (dtype.id()) {
    case cudf::type_id::LIST: {
      // Create empty list column with empty UINT8 child (Spark BinaryType maps to LIST)
      auto offsets_col =
        std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
                                       1,
                                       rmm::device_buffer(sizeof(int32_t), stream, mr),
                                       rmm::device_buffer{},
                                       0);
      // Initialize offset to 0 — even an empty list column needs one offset entry.
      // (cudaMemcpyAsync from pageable host memory is synchronous with respect to the
      // host, so the stack variable is safe here.)
      int32_t zero = 0;
      CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data<int32_t>(),
                                    &zero,
                                    sizeof(int32_t),
                                    cudaMemcpyHostToDevice,
                                    stream.value()));
      auto child_col = std::make_unique<cudf::column>(
        cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0);
      return cudf::make_lists_column(
        0, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr);
    }
    case cudf::type_id::STRUCT: {
      // Create empty struct column with no children
      std::vector<std::unique_ptr<cudf::column>> empty_children;
      return cudf::make_structs_column(
        0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr);
    }
    default:
      // For non-nested types, use cudf's make_empty_column
      return cudf::make_empty_column(dtype);
  }
}
/**
 * Create an all-null LIST column with the provided child column.
 * The child column is expected to have 0 rows.
 *
 * NOTE(review): this block (three helpers) was recovered from a mangled patch;
 * template arguments and kernel launch configurations inside <...> were lost and
 * have been reconstructed — verify against the original file.
 */
std::unique_ptr<cudf::column> make_null_list_column_with_child(
  std::unique_ptr<cudf::column> child_col,
  cudf::size_type num_rows,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr)
{
  // All offsets zero: every row is an empty (and null) list.
  rmm::device_uvector<int32_t> offsets(num_rows + 1, stream, mr);
  thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0);
  auto offsets_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
                                                    num_rows + 1,
                                                    offsets.release(),
                                                    rmm::device_buffer{},
                                                    0);
  auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr);
  return cudf::make_lists_column(num_rows,
                                 std::move(offsets_col),
                                 std::move(child_col),
                                 num_rows,
                                 std::move(null_mask),
                                 stream,
                                 mr);
}

/**
 * Wrap a 0-row element column into a 0-row LIST column.
 */
std::unique_ptr<cudf::column> make_empty_list_column(std::unique_ptr<cudf::column> element_col,
                                                     rmm::cuda_stream_view stream,
                                                     rmm::device_async_resource_ref mr)
{
  auto offsets_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
                                                    1,
                                                    rmm::device_buffer(sizeof(int32_t), stream, mr),
                                                    rmm::device_buffer{},
                                                    0);
  // Single offset entry must be 0; pageable-host async copy is host-synchronous.
  int32_t zero = 0;
  CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data<int32_t>(),
                                &zero,
                                sizeof(int32_t),
                                cudaMemcpyHostToDevice,
                                stream.value()));
  return cudf::make_lists_column(
    0, std::move(offsets_col), std::move(element_col), 0, rmm::device_buffer{}, stream, mr);
}

/**
 * Build a STRING column mapping decoded enum numbers to their enum names.
 *
 * Invalid enum numbers set the per-row d_row_has_invalid_enum flag (consumed by the
 * caller for PERMISSIVE-mode null handling); rows whose field was absent stay null.
 *
 * @param enum_values     Decoded enum numbers per row (device).
 * @param valid           Per-row presence flags (device); becomes the output null mask.
 * @param valid_enums     Host list of legal enum numbers (parallel to enum_name_bytes).
 * @param enum_name_bytes Host UTF-8 name bytes, one entry per legal enum number.
 */
std::unique_ptr<cudf::column> build_enum_string_column(
  rmm::device_uvector<int64_t>& enum_values,  // NOTE(review): element type inferred
  rmm::device_uvector<bool>& valid,           // NOTE(review): element type inferred
  std::vector<int32_t> const& valid_enums,
  std::vector<std::vector<uint8_t>> const& enum_name_bytes,
  rmm::device_uvector<bool>& d_row_has_invalid_enum,
  int num_rows,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr)
{
  auto const threads = THREADS_PER_BLOCK;
  auto const blocks  = static_cast<int>((num_rows + threads - 1) / threads);

  rmm::device_uvector<int32_t> d_valid_enums(valid_enums.size(), stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(),
                                valid_enums.data(),
                                valid_enums.size() * sizeof(int32_t),
                                cudaMemcpyHostToDevice,
                                stream.value()));

  // Flag rows whose decoded number is not in the legal set.
  validate_enum_values_kernel<<<blocks, threads, 0, stream.value()>>>(
    enum_values.data(),
    valid.data(),
    d_row_has_invalid_enum.data(),
    d_valid_enums.data(),
    static_cast<int>(valid_enums.size()),
    num_rows);

  // Flatten the host-side name table into offsets + chars for device lookup.
  std::vector<int32_t> h_name_offsets(valid_enums.size() + 1, 0);
  int32_t total_name_chars = 0;
  for (size_t k = 0; k < enum_name_bytes.size(); ++k) {
    total_name_chars += static_cast<int32_t>(enum_name_bytes[k].size());
    h_name_offsets[k + 1] = total_name_chars;
  }
  std::vector<uint8_t> h_name_chars(total_name_chars);
  int32_t cursor = 0;
  for (auto const& name : enum_name_bytes) {
    if (!name.empty()) {
      std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + cursor);
      cursor += static_cast<int32_t>(name.size());
    }
  }

  rmm::device_uvector<int32_t> d_name_offsets(h_name_offsets.size(), stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(),
                                h_name_offsets.data(),
                                h_name_offsets.size() * sizeof(int32_t),
                                cudaMemcpyHostToDevice,
                                stream.value()));
  rmm::device_uvector<uint8_t> d_name_chars(total_name_chars, stream, mr);
  if (total_name_chars > 0) {
    CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(),
                                  h_name_chars.data(),
                                  total_name_chars * sizeof(uint8_t),
                                  cudaMemcpyHostToDevice,
                                  stream.value()));
  }

  // Standard strings build: lengths -> offsets -> character copy.
  rmm::device_uvector<int32_t> lengths(num_rows, stream, mr);
  compute_enum_string_lengths_kernel<<<blocks, threads, 0, stream.value()>>>(
    enum_values.data(),
    valid.data(),
    d_valid_enums.data(),
    d_name_offsets.data(),
    static_cast<int>(valid_enums.size()),
    lengths.data(),
    num_rows);

  auto [offsets_col, total_chars] = cudf::strings::detail::make_offsets_child_column(
    lengths.begin(), lengths.end(), stream, mr);

  rmm::device_uvector<char> chars(total_chars, stream, mr);
  if (total_chars > 0) {
    copy_enum_string_chars_kernel<<<blocks, threads, 0, stream.value()>>>(
      enum_values.data(),
      valid.data(),
      d_valid_enums.data(),
      d_name_offsets.data(),
      d_name_chars.data(),
      static_cast<int>(valid_enums.size()),
      offsets_col->view().data<cudf::size_type>(),
      chars.data(),
      num_rows);
  }

  auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr);
  return cudf::make_strings_column(
    num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask));
}
make_empty_column_safe( + cudf::data_type{cudf::type_id::LIST}, stream, mr) // LIST + : cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); + + if (input_null_count > 0) { + // Copy input null mask - only input nulls produce output nulls + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask), + stream, + mr); + } else { + // No input nulls, all rows get empty arrays [] + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + 0, + rmm::device_buffer{}, + stream, + mr); + } + } + + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); + + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, + &total_count, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + // Extract string lengths from occurrences + rmm::device_uvector str_lengths(total_count, stream, mr); + auto const threads = THREADS_PER_BLOCK; + auto const blocks = (total_count + threads - 1) / threads; + RepeatedLocationProvider loc_provider{nullptr, 0, d_occurrences.data()}; + extract_lengths_kernel + <<>>(loc_provider, total_count, str_lengths.data()); + + auto [str_offsets_col, total_chars] = cudf::strings::detail::make_offsets_child_column( + str_lengths.begin(), str_lengths.end(), stream, mr); + + rmm::device_uvector chars(total_chars, stream, mr); + rmm::device_uvector d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + if (total_chars > 0) { + RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; + copy_varlen_data_kernel<<>>( + message_data, + loc_provider, + total_count, + str_offsets_col->view().data(), + chars.data(), + d_error.data()); + } + + std::unique_ptr child_col; 
+ if (is_bytes) { + auto bytes_child = + std::make_unique(cudf::data_type{cudf::type_id::UINT8}, + total_chars, + rmm::device_buffer(chars.data(), total_chars, stream, mr), + rmm::device_buffer{}, + 0); + child_col = cudf::make_lists_column(total_count, + std::move(str_offsets_col), + std::move(bytes_child), + 0, + rmm::device_buffer{}, + stream, + mr); + } else { + child_col = cudf::make_strings_column( + total_count, std::move(str_offsets_col), chars.release(), 0, rmm::device_buffer{}); + } + + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + + // Only rows where INPUT is null should produce null output + // Rows with valid input but count=0 should produce empty array [] + if (input_null_count > 0) { + // Copy input null mask - only input nulls produce output nulls + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask), + stream, + mr); + } + + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); +} + +// Forward declaration -- build_nested_struct_column is defined after build_repeated_struct_column +// but the latter's STRUCT-child case needs to call it. 
std::unique_ptr<cudf::column> build_nested_struct_column(
  uint8_t const* message_data,
  cudf::size_type const* list_offsets,
  cudf::size_type base_offset,
  rmm::device_uvector<field_location> const& d_parent_locs,
  std::vector<int> const& child_field_indices,
  std::vector<nested_field_descriptor> const& schema,
  int num_fields,
  std::vector<cudf::data_type> const& schema_output_types,
  std::vector<int64_t> const& default_ints,
  std::vector<double> const& default_floats,
  std::vector<bool> const& default_bools,
  std::vector<std::vector<uint8_t>> const& default_strings,
  std::vector<std::vector<int32_t>> const& enum_valid_values,
  std::vector<std::vector<std::vector<uint8_t>>> const& enum_names,
  rmm::device_uvector<bool>& d_row_has_invalid_enum,
  rmm::device_uvector<int>& d_error,
  int num_rows,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr,
  int depth);

// Forward declaration -- build_repeated_child_list_column is defined after
// build_nested_struct_column but both build_repeated_struct_column and build_nested_struct_column
// need to call it.
std::unique_ptr<cudf::column> build_repeated_child_list_column(
  uint8_t const* message_data,
  cudf::size_type const* row_offsets,
  cudf::size_type base_offset,
  field_location const* parent_locs,
  int num_parent_rows,
  int child_schema_idx,
  std::vector<nested_field_descriptor> const& schema,
  int num_fields,
  std::vector<cudf::data_type> const& schema_output_types,
  std::vector<int64_t> const& default_ints,
  std::vector<double> const& default_floats,
  std::vector<bool> const& default_bools,
  std::vector<std::vector<uint8_t>> const& default_strings,
  std::vector<std::vector<int32_t>> const& enum_valid_values,
  std::vector<std::vector<std::vector<uint8_t>>> const& enum_names,
  rmm::device_uvector<bool>& d_row_has_invalid_enum,
  rmm::device_uvector<int>& d_error,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr,
  int depth);

/**
 * Build a LIST<STRUCT<...>> column for a top-level repeated message field.
 *
 * Each occurrence of the repeated message becomes one struct element; its child
 * fields are scanned inside the occurrence and extracted per output type.
 * Null semantics: only input-null rows are null; valid rows with zero
 * occurrences become empty lists.
 */
std::unique_ptr<cudf::column> build_repeated_struct_column(
  cudf::column_view const& binary_input,
  uint8_t const* message_data,
  cudf::size_type const* list_offsets,
  cudf::size_type base_offset,
  device_nested_field_descriptor const& field_desc,
  rmm::device_uvector<int32_t> const& d_field_counts,
  rmm::device_uvector<repeated_occurrence>& d_occurrences,
  int total_count,
  int num_rows,
  std::vector<device_nested_field_descriptor> const& h_device_schema,
  std::vector<int> const& child_field_indices,
  std::vector<cudf::data_type> const& schema_output_types,
  std::vector<int64_t> const& default_ints,
  std::vector<double> const& default_floats,
  std::vector<bool> const& default_bools,
  std::vector<std::vector<uint8_t>> const& default_strings,
  std::vector<nested_field_descriptor> const& schema,
  std::vector<std::vector<int32_t>> const& enum_valid_values,
  std::vector<std::vector<std::vector<uint8_t>>> const& enum_names,
  rmm::device_uvector<bool>& d_row_has_invalid_enum,
  rmm::device_uvector<int>& d_error_top,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr)
{
  auto const input_null_count = binary_input.null_count();
  int num_child_fields        = static_cast<int>(child_field_indices.size());

  if (total_count == 0 || num_child_fields == 0) {
    // All rows have count=0 or no child fields - return list of empty structs
    rmm::device_uvector<int32_t> offsets(num_rows + 1, stream, mr);
    thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0);
    auto offsets_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
                                                      num_rows + 1,
                                                      offsets.release(),
                                                      rmm::device_buffer{},
                                                      0);

    // Build empty struct child column with proper nested structure
    int num_schema_fields = static_cast<int>(h_device_schema.size());
    std::vector<std::unique_ptr<cudf::column>> empty_struct_children;
    for (int child_schema_idx : child_field_indices) {
      auto child_type = schema_output_types[child_schema_idx];
      std::unique_ptr<cudf::column> child_col;
      if (child_type.id() == cudf::type_id::STRUCT) {
        child_col = make_empty_struct_column_with_schema(
          h_device_schema, schema_output_types, child_schema_idx, num_schema_fields, stream, mr);
      } else {
        child_col = make_empty_column_safe(child_type, stream, mr);
      }
      if (h_device_schema[child_schema_idx].is_repeated) {
        child_col = make_empty_list_column(std::move(child_col), stream, mr);
      }
      empty_struct_children.push_back(std::move(child_col));
    }
    auto empty_struct = cudf::make_structs_column(
      0, std::move(empty_struct_children), 0, rmm::device_buffer{}, stream, mr);

    if (input_null_count > 0) {
      auto null_mask = cudf::copy_bitmask(binary_input, stream, mr);
      return cudf::make_lists_column(num_rows,
                                     std::move(offsets_col),
                                     std::move(empty_struct),
                                     input_null_count,
                                     std::move(null_mask),
                                     stream,
                                     mr);
    } else {
      return cudf::make_lists_column(num_rows,
                                     std::move(offsets_col),
                                     std::move(empty_struct),
                                     0,
                                     rmm::device_buffer{},
                                     stream,
                                     mr);
    }
  }

  rmm::device_uvector<int32_t> list_offs(num_rows + 1, stream, mr);
  thrust::exclusive_scan(
    rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0);

  // Patch the terminating offset with the total occurrence count.
  CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows,
                                &total_count,
                                sizeof(int32_t),
                                cudaMemcpyHostToDevice,
                                stream.value()));

  // Build child field descriptors for scanning within each message occurrence
  std::vector<field_descriptor> h_child_descs(num_child_fields);
  for (int ci = 0; ci < num_child_fields; ci++) {
    int child_schema_idx                 = child_field_indices[ci];
    h_child_descs[ci].field_number       = h_device_schema[child_schema_idx].field_number;
    h_child_descs[ci].expected_wire_type = h_device_schema[child_schema_idx].wire_type;
  }
  rmm::device_uvector<field_descriptor> d_child_descs(num_child_fields, stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_descs.data(),
                                h_child_descs.data(),
                                num_child_fields * sizeof(field_descriptor),
                                cudaMemcpyHostToDevice,
                                stream.value()));

  // For each occurrence, we need to scan for child fields
  // Create "virtual" parent locations from the occurrences using GPU kernel
  // This replaces the host-side loop with D->H->D copy pattern (critical performance fix!)
  rmm::device_uvector<field_location> d_msg_locs(total_count, stream, mr);
  rmm::device_uvector<int32_t> d_msg_row_offsets(total_count, stream, mr);
  rmm::device_uvector<cudf::size_type> d_msg_row_offsets_size(total_count, stream, mr);
  {
    auto const occ_threads = THREADS_PER_BLOCK;
    auto const occ_blocks  = (total_count + occ_threads - 1) / occ_threads;
    compute_msg_locations_from_occurrences_kernel<<<occ_blocks, occ_threads, 0, stream.value()>>>(
      d_occurrences.data(),
      list_offsets,
      base_offset,
      d_msg_locs.data(),
      d_msg_row_offsets.data(),
      total_count);
  }
  // Widen the int32 offsets into cudf::size_type for APIs that take size_type pointers.
  thrust::transform(rmm::exec_policy(stream),
                    d_msg_row_offsets.data(),
                    d_msg_row_offsets.end(),
                    d_msg_row_offsets_size.data(),
                    [] __device__(int32_t v) { return static_cast<cudf::size_type>(v); });

  // Scan for child fields within each message occurrence
  rmm::device_uvector<field_location> d_child_locs(total_count * num_child_fields, stream, mr);
  // Reuse top-level error flag so failfast can observe nested repeated-message failures.
  auto& d_error = d_error_top;

  auto const threads = THREADS_PER_BLOCK;
  auto const blocks  = (total_count + threads - 1) / threads;

  // Use a custom kernel to scan child fields within message occurrences
  // This is similar to scan_nested_message_fields_kernel but operates on occurrences
  scan_repeated_message_children_kernel<<<blocks, threads, 0, stream.value()>>>(
    message_data,
    d_msg_row_offsets.data(),
    d_msg_locs.data(),
    total_count,
    d_child_descs.data(),
    num_child_fields,
    d_child_locs.data(),
    d_error.data());

  // Note: We no longer need to copy child_locs to host because:
  // 1. All scalar extraction kernels access d_child_locs directly on device
  // 2. String extraction uses GPU kernels
  // 3. Nested struct locations are computed on GPU via compute_nested_struct_locations_kernel

  // Extract child field values - build one column per child field
  std::vector<std::unique_ptr<cudf::column>> struct_children;
  int num_schema_fields = static_cast<int>(h_device_schema.size());
  for (int ci = 0; ci < num_child_fields; ci++) {
    int child_schema_idx   = child_field_indices[ci];
    auto const dt          = schema_output_types[child_schema_idx];
    auto const enc         = h_device_schema[child_schema_idx].encoding;
    bool has_def           = h_device_schema[child_schema_idx].has_default_value;
    bool child_is_repeated = h_device_schema[child_schema_idx].is_repeated;

    if (child_is_repeated) {
      // Repeated field nested inside a repeated message: each occurrence acts as
      // a parent row (offsets are already occurrence-relative, hence base 0).
      struct_children.push_back(build_repeated_child_list_column(message_data,
                                                                 d_msg_row_offsets_size.data(),
                                                                 0,
                                                                 d_msg_locs.data(),
                                                                 total_count,
                                                                 child_schema_idx,
                                                                 schema,
                                                                 num_schema_fields,
                                                                 schema_output_types,
                                                                 default_ints,
                                                                 default_floats,
                                                                 default_bools,
                                                                 default_strings,
                                                                 enum_valid_values,
                                                                 enum_names,
                                                                 d_row_has_invalid_enum,
                                                                 d_error_top,
                                                                 stream,
                                                                 mr,
                                                                 1));
      continue;
    }

    switch (dt.id()) {
      case cudf::type_id::BOOL8:
      case cudf::type_id::INT32:
      case cudf::type_id::UINT32:
      case cudf::type_id::INT64:
      case cudf::type_id::UINT64:
      case cudf::type_id::FLOAT32:
      case cudf::type_id::FLOAT64: {
        RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(),
                                                      0,
                                                      d_msg_locs.data(),
                                                      d_child_locs.data(),
                                                      ci,
                                                      num_child_fields};
        struct_children.push_back(
          extract_typed_column(dt,
                               enc,
                               message_data,
                               loc_provider,
                               total_count,
                               blocks,
                               threads,
                               has_def,
                               has_def ? default_ints[child_schema_idx] : 0,
                               has_def ? default_floats[child_schema_idx] : 0.0,
                               has_def ? default_bools[child_schema_idx] : false,
                               default_strings[child_schema_idx],
                               child_schema_idx,
                               enum_valid_values,
                               enum_names,
                               d_row_has_invalid_enum,
                               d_error,
                               stream,
                               mr));
        break;
      }
      case cudf::type_id::STRING: {
        struct_children.push_back(build_repeated_msg_child_varlen_column(message_data,
                                                                         d_msg_row_offsets,
                                                                         d_msg_locs,
                                                                         d_child_locs,
                                                                         ci,
                                                                         num_child_fields,
                                                                         total_count,
                                                                         d_error,
                                                                         false,
                                                                         stream,
                                                                         mr));
        break;
      }
      case cudf::type_id::LIST: {
        // bytes (BinaryType) is represented as LIST of UINT8
        struct_children.push_back(build_repeated_msg_child_varlen_column(message_data,
                                                                         d_msg_row_offsets,
                                                                         d_msg_locs,
                                                                         d_child_locs,
                                                                         ci,
                                                                         num_child_fields,
                                                                         total_count,
                                                                         d_error,
                                                                         true,
                                                                         stream,
                                                                         mr));
        break;
      }
      case cudf::type_id::STRUCT: {
        // Nested struct inside repeated message - use recursive build_nested_struct_column
        int num_schema_fields = static_cast<int>(h_device_schema.size());
        auto grandchild_indices =
          find_child_field_indices(h_device_schema, num_schema_fields, child_schema_idx);

        if (grandchild_indices.empty()) {
          struct_children.push_back(
            cudf::make_structs_column(total_count,
                                      std::vector<std::unique_ptr<cudf::column>>{},
                                      0,
                                      rmm::device_buffer{},
                                      stream,
                                      mr));
        } else {
          // Compute virtual parent locations for each occurrence's nested struct child
          rmm::device_uvector<field_location> d_nested_locs(total_count, stream, mr);
          rmm::device_uvector<cudf::size_type> d_nested_row_offsets(total_count, stream, mr);
          {
            // Convert int32_t row offsets to cudf::size_type and compute nested struct locations
            rmm::device_uvector<int32_t> d_nested_row_offsets_i32(total_count, stream, mr);
            compute_nested_struct_locations_kernel<<<blocks, threads, 0, stream.value()>>>(
              d_child_locs.data(),
              d_msg_locs.data(),
              d_msg_row_offsets.data(),
              ci,
              num_child_fields,
              d_nested_locs.data(),
              d_nested_row_offsets_i32.data(),
              total_count);
            // Add base_offset back so build_nested_struct_column can subtract it
            thrust::transform(rmm::exec_policy(stream),
                              d_nested_row_offsets_i32.data(),
                              d_nested_row_offsets_i32.end(),
                              d_nested_row_offsets.data(),
                              [base_offset] __device__(int32_t v) {
                                return static_cast<cudf::size_type>(v) + base_offset;
                              });
          }

          struct_children.push_back(build_nested_struct_column(message_data,
                                                               d_nested_row_offsets.data(),
                                                               base_offset,
                                                               d_nested_locs,
                                                               grandchild_indices,
                                                               schema,
                                                               num_schema_fields,
                                                               schema_output_types,
                                                               default_ints,
                                                               default_floats,
                                                               default_bools,
                                                               default_strings,
                                                               enum_valid_values,
                                                               enum_names,
                                                               d_row_has_invalid_enum,
                                                               d_error_top,
                                                               total_count,
                                                               stream,
                                                               mr,
                                                               0));
        }
        break;
      }
      default:
        // Unsupported child type - create null column
        struct_children.push_back(make_null_column(dt, total_count, stream, mr));
        break;
    }
  }

  // Build the struct column from child columns
  auto struct_col = cudf::make_structs_column(
    total_count, std::move(struct_children), 0, rmm::device_buffer{}, stream, mr);

  // Build the list offsets column
  auto offsets_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
                                                    num_rows + 1,
                                                    list_offs.release(),
                                                    rmm::device_buffer{},
                                                    0);

  // Build the final LIST column
  if (input_null_count > 0) {
    auto null_mask = cudf::copy_bitmask(binary_input, stream, mr);
    return cudf::make_lists_column(num_rows,
                                   std::move(offsets_col),
                                   std::move(struct_col),
                                   input_null_count,
                                   std::move(null_mask),
                                   stream,
                                   mr);
  }

  return cudf::make_lists_column(
    num_rows, std::move(offsets_col), std::move(struct_col), 0, rmm::device_buffer{}, stream, mr);
}

/**
 * Recursively decode the children of a nested (message-typed) struct field.
 *
 * Parent locations with offset < 0 mark rows where the parent message is absent;
 * those rows come out null in the resulting STRUCT column. Recursion depth is
 * bounded by MAX_NESTED_STRUCT_DECODE_DEPTH.
 */
std::unique_ptr<cudf::column> build_nested_struct_column(
  uint8_t const* message_data,
  cudf::size_type const* list_offsets,
  cudf::size_type base_offset,
  rmm::device_uvector<field_location> const& d_parent_locs,
  std::vector<int> const& child_field_indices,
  std::vector<nested_field_descriptor> const& schema,
  int num_fields,
  std::vector<cudf::data_type> const& schema_output_types,
  std::vector<int64_t> const& default_ints,
  std::vector<double> const& default_floats,
  std::vector<bool> const& default_bools,
  std::vector<std::vector<uint8_t>> const& default_strings,
  std::vector<std::vector<int32_t>> const& enum_valid_values,
  std::vector<std::vector<std::vector<uint8_t>>> const& enum_names,
  rmm::device_uvector<bool>& d_row_has_invalid_enum,
  rmm::device_uvector<int>& d_error,
  int num_rows,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr,
  int depth)
{
  CUDF_EXPECTS(depth < MAX_NESTED_STRUCT_DECODE_DEPTH,
               "Nested protobuf struct depth exceeds supported decode recursion limit");

  if (num_rows == 0) {
    // Preserve the full nested shape even when there are no rows.
    std::vector<std::unique_ptr<cudf::column>> empty_children;
    for (int child_schema_idx : child_field_indices) {
      auto child_type = schema_output_types[child_schema_idx];
      std::unique_ptr<cudf::column> child_col;
      if (child_type.id() == cudf::type_id::STRUCT) {
        child_col = make_empty_struct_column_with_schema(
          schema, schema_output_types, child_schema_idx, num_fields, stream, mr);
      } else {
        child_col = make_empty_column_safe(child_type, stream, mr);
      }
      if (schema[child_schema_idx].is_repeated) {
        child_col = make_empty_list_column(std::move(child_col), stream, mr);
      }
      empty_children.push_back(std::move(child_col));
    }
    return cudf::make_structs_column(
      0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr);
  }

  auto const threads   = THREADS_PER_BLOCK;
  auto const blocks    = static_cast<int>((num_rows + threads - 1) / threads);
  int num_child_fields = static_cast<int>(child_field_indices.size());

  std::vector<field_descriptor> h_child_field_descs(num_child_fields);
  for (int i = 0; i < num_child_fields; i++) {
    int child_idx                             = child_field_indices[i];
    h_child_field_descs[i].field_number       = schema[child_idx].field_number;
    h_child_field_descs[i].expected_wire_type = schema[child_idx].wire_type;
  }

  rmm::device_uvector<field_descriptor> d_child_field_descs(num_child_fields, stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_field_descs.data(),
                                h_child_field_descs.data(),
                                num_child_fields * sizeof(field_descriptor),
                                cudaMemcpyHostToDevice,
                                stream.value()));

  // One location slot per (row, child field), row-major.
  rmm::device_uvector<field_location> d_child_locations(
    static_cast<size_t>(num_rows) * num_child_fields, stream, mr);
  scan_nested_message_fields_kernel<<<blocks, threads, 0, stream.value()>>>(
    message_data,
    list_offsets,
    base_offset,
    d_parent_locs.data(),
    num_rows,
    d_child_field_descs.data(),
    num_child_fields,
    d_child_locations.data(),
    d_error.data());

  std::vector<std::unique_ptr<cudf::column>> struct_children;
  for (int ci = 0; ci < num_child_fields; ci++) {
    int child_schema_idx = child_field_indices[ci];
    auto const dt        = schema_output_types[child_schema_idx];
    auto const enc       = schema[child_schema_idx].encoding;
    bool has_def         = schema[child_schema_idx].has_default_value;
    bool is_repeated     = schema[child_schema_idx].is_repeated;

    if (is_repeated) {
      struct_children.push_back(build_repeated_child_list_column(message_data,
                                                                 list_offsets,
                                                                 base_offset,
                                                                 d_parent_locs.data(),
                                                                 num_rows,
                                                                 child_schema_idx,
                                                                 schema,
                                                                 num_fields,
                                                                 schema_output_types,
                                                                 default_ints,
                                                                 default_floats,
                                                                 default_bools,
                                                                 default_strings,
                                                                 enum_valid_values,
                                                                 enum_names,
                                                                 d_row_has_invalid_enum,
                                                                 d_error,
                                                                 stream,
                                                                 mr,
                                                                 depth));
      continue;
    }

    switch (dt.id()) {
      case cudf::type_id::BOOL8:
      case cudf::type_id::INT32:
      case cudf::type_id::UINT32:
      case cudf::type_id::INT64:
      case cudf::type_id::UINT64:
      case cudf::type_id::FLOAT32:
      case cudf::type_id::FLOAT64: {
        NestedLocationProvider loc_provider{list_offsets,
                                            base_offset,
                                            d_parent_locs.data(),
                                            d_child_locations.data(),
                                            ci,
                                            num_child_fields};
        struct_children.push_back(
          extract_typed_column(dt,
                               enc,
                               message_data,
                               loc_provider,
                               num_rows,
                               blocks,
                               threads,
                               has_def,
                               has_def ? default_ints[child_schema_idx] : 0,
                               has_def ? default_floats[child_schema_idx] : 0.0,
                               has_def ? default_bools[child_schema_idx] : false,
                               default_strings[child_schema_idx],
                               child_schema_idx,
                               enum_valid_values,
                               enum_names,
                               d_row_has_invalid_enum,
                               d_error,
                               stream,
                               mr));
        break;
      }
      case cudf::type_id::STRING: {
        if (enc == spark_rapids_jni::ENC_ENUM_STRING) {
          // Enum rendered as its name: decode the varint value first, then map to names.
          rmm::device_uvector<int64_t> out(num_rows, stream, mr);
          rmm::device_uvector<bool> valid((num_rows > 0 ? num_rows : 1), stream, mr);
          int64_t def_int = has_def ? default_ints[child_schema_idx] : 0;
          NestedLocationProvider loc_provider{list_offsets,
                                              base_offset,
                                              d_parent_locs.data(),
                                              d_child_locations.data(),
                                              ci,
                                              num_child_fields};
          extract_varint_kernel<<<blocks, threads, 0, stream.value()>>>(message_data,
                                                                        loc_provider,
                                                                        num_rows,
                                                                        out.data(),
                                                                        valid.data(),
                                                                        d_error.data(),
                                                                        has_def,
                                                                        def_int);

          if (child_schema_idx < static_cast<int>(enum_valid_values.size()) &&
              child_schema_idx < static_cast<int>(enum_names.size())) {
            auto const& valid_enums     = enum_valid_values[child_schema_idx];
            auto const& enum_name_bytes = enum_names[child_schema_idx];
            if (!valid_enums.empty() && valid_enums.size() == enum_name_bytes.size()) {
              struct_children.push_back(build_enum_string_column(out,
                                                                 valid,
                                                                 valid_enums,
                                                                 enum_name_bytes,
                                                                 d_row_has_invalid_enum,
                                                                 num_rows,
                                                                 stream,
                                                                 mr));
            } else {
              // Inconsistent enum metadata: flag a decode error, emit nulls.
              CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value()));
              struct_children.push_back(make_null_column(dt, num_rows, stream, mr));
            }
          } else {
            CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value()));
            struct_children.push_back(make_null_column(dt, num_rows, stream, mr));
          }
        } else {
          bool has_def_str    = has_def;
          auto const& def_str = default_strings[child_schema_idx];
          // Length pass needs occurrence-relative offsets (null base), copy pass absolute ones.
          NestedLocationProvider len_provider{
            nullptr, 0, d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields};
          NestedLocationProvider copy_provider{list_offsets,
                                               base_offset,
                                               d_parent_locs.data(),
                                               d_child_locations.data(),
                                               ci,
                                               num_child_fields};
          auto valid_fn = [plocs = d_parent_locs.data(),
                           flocs = d_child_locations.data(),
                           ci,
                           num_child_fields,
                           has_def_str] __device__(cudf::size_type row) {
            return (plocs[row].offset >= 0 && flocs[row * num_child_fields + ci].offset >= 0) ||
                   has_def_str;
          };
          struct_children.push_back(extract_and_build_string_or_bytes_column(false,
                                                                             message_data,
                                                                             num_rows,
                                                                             len_provider,
                                                                             copy_provider,
                                                                             valid_fn,
                                                                             has_def_str,
                                                                             def_str,
                                                                             d_error,
                                                                             stream,
                                                                             mr));
        }
        break;
      }
      case cudf::type_id::LIST: {
        // bytes (BinaryType) represented as LIST
        bool has_def_bytes    = has_def;
        auto const& def_bytes = default_strings[child_schema_idx];
        NestedLocationProvider len_provider{
          nullptr, 0, d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields};
        NestedLocationProvider copy_provider{list_offsets,
                                             base_offset,
                                             d_parent_locs.data(),
                                             d_child_locations.data(),
                                             ci,
                                             num_child_fields};
        auto valid_fn = [plocs = d_parent_locs.data(),
                         flocs = d_child_locations.data(),
                         ci,
                         num_child_fields,
                         has_def_bytes] __device__(cudf::size_type row) {
          return (plocs[row].offset >= 0 && flocs[row * num_child_fields + ci].offset >= 0) ||
                 has_def_bytes;
        };
        struct_children.push_back(extract_and_build_string_or_bytes_column(true,
                                                                           message_data,
                                                                           num_rows,
                                                                           len_provider,
                                                                           copy_provider,
                                                                           valid_fn,
                                                                           has_def_bytes,
                                                                           def_bytes,
                                                                           d_error,
                                                                           stream,
                                                                           mr));
        break;
      }
      case cudf::type_id::STRUCT: {
        auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx);
        if (gc_indices.empty()) {
          struct_children.push_back(make_null_column(dt, num_rows, stream, mr));
          break;
        }
        // Project this child's location as the parent location for the next level down.
        rmm::device_uvector<field_location> d_gc_parent(num_rows, stream, mr);
        compute_grandchild_parent_locations_kernel<<<blocks, threads, 0, stream.value()>>>(
          d_parent_locs.data(),
          d_child_locations.data(),
          ci,
          num_child_fields,
          d_gc_parent.data(),
          num_rows);
        struct_children.push_back(build_nested_struct_column(message_data,
                                                             list_offsets,
                                                             base_offset,
                                                             d_gc_parent,
                                                             gc_indices,
                                                             schema,
                                                             num_fields,
                                                             schema_output_types,
                                                             default_ints,
                                                             default_floats,
                                                             default_bools,
                                                             default_strings,
                                                             enum_valid_values,
                                                             enum_names,
                                                             d_row_has_invalid_enum,
                                                             d_error,
                                                             num_rows,
                                                             stream,
                                                             mr,
                                                             depth + 1));
        break;
      }
      default: struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); break;
    }
  }

  // A struct row is valid iff its parent message was found (offset >= 0).
  rmm::device_uvector<bool> struct_valid((num_rows > 0 ? num_rows : 1), stream, mr);
  thrust::transform(
    rmm::exec_policy(stream),
    thrust::make_counting_iterator(0),
    thrust::make_counting_iterator(num_rows),
    struct_valid.data(),
    [plocs = d_parent_locs.data()] __device__(auto row) { return plocs[row].offset >= 0; });
  auto [struct_mask, struct_null_count] = make_null_mask_from_valid(struct_valid, stream, mr);
  return cudf::make_structs_column(
    num_rows, std::move(struct_children), struct_null_count, std::move(struct_mask), stream, mr);
}

/**
 * Build a LIST column for a repeated child field inside a parent message.
 * Shared between build_nested_struct_column and build_repeated_struct_column.
 */
std::unique_ptr<cudf::column> build_repeated_child_list_column(
  uint8_t const* message_data,
  cudf::size_type const* row_offsets,
  cudf::size_type base_offset,
  field_location const* parent_locs,
  int num_parent_rows,
  int child_schema_idx,
  std::vector<nested_field_descriptor> const& schema,
  int num_fields,
  std::vector<cudf::data_type> const& schema_output_types,
  std::vector<int64_t> const& default_ints,
  std::vector<double> const& default_floats,
  std::vector<bool> const& default_bools,
  std::vector<std::vector<uint8_t>> const& default_strings,
  std::vector<std::vector<int32_t>> const& enum_valid_values,
  std::vector<std::vector<std::vector<uint8_t>>> const& enum_names,
  rmm::device_uvector<bool>& d_row_has_invalid_enum,
  rmm::device_uvector<int>& d_error,
  rmm::cuda_stream_view stream,
  rmm::device_async_resource_ref mr,
  int depth)
{
  auto const threads = THREADS_PER_BLOCK;
  auto const blocks  = static_cast<int>((num_parent_rows + threads - 1) / threads);

  auto elem_type_id = schema[child_schema_idx].output_type;
  rmm::device_uvector<repeated_field_info> d_rep_info(num_parent_rows, stream, mr);

  // Single-field "mini schema": the count/scan kernels operate on a schema array,
  // so wrap just this child field as index 0.
  std::vector<int> rep_indices = {0};
  rmm::device_uvector<int> d_rep_indices(1, stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(
    d_rep_indices.data(), rep_indices.data(), sizeof(int), cudaMemcpyHostToDevice, stream.value()));

  device_nested_field_descriptor rep_desc;
  rep_desc.field_number      = schema[child_schema_idx].field_number;
  rep_desc.wire_type         = schema[child_schema_idx].wire_type;
  rep_desc.output_type_id    = static_cast<int>(schema[child_schema_idx].output_type);
  rep_desc.is_repeated       = true;
  rep_desc.parent_idx        = -1;
  rep_desc.depth             = 0;
  rep_desc.encoding          = 0;
  rep_desc.is_required       = false;
  rep_desc.has_default_value = false;

  std::vector<device_nested_field_descriptor> h_rep_schema = {rep_desc};
  rmm::device_uvector<device_nested_field_descriptor> d_rep_schema(1, stream, mr);
  CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_schema.data(),
                                h_rep_schema.data(),
                                sizeof(device_nested_field_descriptor),
                                cudaMemcpyHostToDevice,
                                stream.value()));

  // Pass 1: count occurrences per parent row.
  count_repeated_in_nested_kernel<<<blocks, threads, 0, stream.value()>>>(message_data,
                                                                          row_offsets,
                                                                          base_offset,
                                                                          parent_locs,
                                                                          num_parent_rows,
                                                                          d_rep_schema.data(),
                                                                          1,
                                                                          d_rep_info.data(),
                                                                          1,
                                                                          d_rep_indices.data(),
                                                                          d_error.data());

  rmm::device_uvector<int32_t> d_rep_counts(num_parent_rows, stream, mr);
  thrust::transform(rmm::exec_policy(stream),
                    d_rep_info.data(),
                    d_rep_info.end(),
                    d_rep_counts.data(),
                    [] __device__(repeated_field_info const& info) { return info.count; });
  int total_rep_count =
    thrust::reduce(rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), 0);

  if (total_rep_count == 0) {
    // No occurrences anywhere: all parent rows get empty lists.
    rmm::device_uvector<int32_t> list_offsets_vec(num_parent_rows + 1, stream, mr);
    thrust::fill(rmm::exec_policy(stream), list_offsets_vec.data(), list_offsets_vec.end(), 0);
    auto list_offsets_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
                                                           num_parent_rows + 1,
                                                           list_offsets_vec.release(),
                                                           rmm::device_buffer{},
                                                           0);
    std::unique_ptr<cudf::column> child_col;
    if (elem_type_id == cudf::type_id::STRUCT) {
      child_col = make_empty_struct_column_with_schema(
        schema, schema_output_types, child_schema_idx, num_fields, stream, mr);
    } else {
      child_col = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr);
    }
    return cudf::make_lists_column(num_parent_rows,
                                   std::move(list_offsets_col),
                                   std::move(child_col),
                                   0,
                                   rmm::device_buffer{},
                                   stream,
                                   mr);
  }

  rmm::device_uvector<int32_t> list_offs(num_parent_rows + 1, stream, mr);
  thrust::exclusive_scan(
    rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), list_offs.begin(), 0);
  CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_parent_rows,
                                &total_rep_count,
                                sizeof(int32_t),
                                cudaMemcpyHostToDevice,
                                stream.value()));

  // Pass 2: record every occurrence's location, bucketed via the scanned offsets.
  rmm::device_uvector<repeated_occurrence> d_rep_occs(total_rep_count, stream, mr);
  scan_repeated_in_nested_kernel<<<blocks, threads, 0, stream.value()>>>(message_data,
                                                                         row_offsets,
                                                                         base_offset,
                                                                         parent_locs,
                                                                         num_parent_rows,
                                                                         d_rep_schema.data(),
                                                                         1,
                                                                         list_offs.data(),
                                                                         1,
                                                                         d_rep_indices.data(),
                                                                         d_rep_occs.data(),
                                                                         d_error.data());

  std::unique_ptr<cudf::column> child_values;
  auto const rep_blocks =
    static_cast<int>((total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK);
  NestedRepeatedLocationProvider nr_loc{row_offsets, base_offset, parent_locs, d_rep_occs.data()};

  if (elem_type_id == cudf::type_id::BOOL8 || elem_type_id == cudf::type_id::INT32 ||
      elem_type_id == cudf::type_id::UINT32 || elem_type_id == cudf::type_id::INT64 ||
      elem_type_id == cudf::type_id::UINT64 || elem_type_id == cudf::type_id::FLOAT32 ||
      elem_type_id == cudf::type_id::FLOAT64) {
    // Repeated elements never carry defaults: every occurrence is present by definition.
    child_values = extract_typed_column(cudf::data_type{elem_type_id},
                                        schema[child_schema_idx].encoding,
                                        message_data,
                                        nr_loc,
                                        total_rep_count,
                                        rep_blocks,
                                        THREADS_PER_BLOCK,
                                        false,
                                        0,
                                        0.0,
                                        false,
                                        std::vector<uint8_t>{},
                                        child_schema_idx,
                                        enum_valid_values,
                                        enum_names,
                                        d_row_has_invalid_enum,
                                        d_error,
                                        stream,
                                        mr);
  } else if (elem_type_id == cudf::type_id::STRING || elem_type_id == cudf::type_id::LIST) {
    bool as_bytes = (elem_type_id == cudf::type_id::LIST);
    auto valid_fn = [] __device__(cudf::size_type) { return true; };
    std::vector<uint8_t> empty_default;
    child_values = extract_and_build_string_or_bytes_column(as_bytes,
                                                            message_data,
                                                            total_rep_count,
                                                            nr_loc,
                                                            nr_loc,
                                                            valid_fn,
                                                            false,
                                                            empty_default,
                                                            d_error,
                                                            stream,
                                                            mr);
  } else if (elem_type_id == cudf::type_id::STRUCT) {
    auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx);
    if (gc_indices.empty()) {
      child_values = cudf::make_structs_column(total_rep_count,
                                               std::vector<std::unique_ptr<cudf::column>>{},
                                               0,
                                               rmm::device_buffer{},
                                               stream,
                                               mr);
    } else {
      // Treat each occurrence as a virtual parent row for the recursive decode.
      rmm::device_uvector<cudf::size_type> d_virtual_row_offsets(total_rep_count, stream, mr);
      rmm::device_uvector<field_location> d_virtual_parent_locs(total_rep_count, stream, mr);
      auto const rep_blk = (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
      compute_virtual_parents_for_nested_repeated_kernel<<<rep_blk,
                                                           THREADS_PER_BLOCK,
                                                           0,
                                                           stream.value()>>>(
        d_rep_occs.data(),
        row_offsets,
        parent_locs,
        d_virtual_row_offsets.data(),
        d_virtual_parent_locs.data(),
        total_rep_count);

      child_values = build_nested_struct_column(message_data,
                                                d_virtual_row_offsets.data(),
                                                base_offset,
                                                d_virtual_parent_locs,
                                                gc_indices,
                                                schema,
                                                num_fields,
                                                schema_output_types,
                                                default_ints,
                                                default_floats,
                                                default_bools,
                                                default_strings,
                                                enum_valid_values,
                                                enum_names,
                                                d_row_has_invalid_enum,
                                                d_error,
                                                total_rep_count,
                                                stream,
                                                mr,
                                                depth + 1);
    }
  } else {
    child_values = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr);
  }

  auto list_offs_col = std::make_unique<cudf::column>(cudf::data_type{cudf::type_id::INT32},
                                                      num_parent_rows + 1,
                                                      list_offs.release(),
                                                      rmm::device_buffer{},
                                                      0);
  return cudf::make_lists_column(num_parent_rows,
                                 std::move(list_offs_col),
                                 std::move(child_values),
                                 0,
                                 rmm::device_buffer{},
                                 stream,
                                 mr);
}

}  // namespace spark_rapids_jni::protobuf_detail
diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh
new file mode 100644
index 0000000000..4a625581b8
--- /dev/null
+++ b/src/main/cpp/src/protobuf_common.cuh
@@ -0,0 +1,1375 @@
/*
 * Copyright (c) 2026, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "protobuf.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace spark_rapids_jni::protobuf_detail { + +// Wire type constants (protobuf encoding spec) +constexpr int WT_VARINT = 0; +constexpr int WT_64BIT = 1; +constexpr int WT_LEN = 2; +constexpr int WT_SGROUP = 3; +constexpr int WT_EGROUP = 4; +constexpr int WT_32BIT = 5; + +// Protobuf varint encoding uses at most 10 bytes to represent a 64-bit value. +constexpr int MAX_VARINT_BYTES = 10; + +// CUDA kernel launch configuration. +constexpr int THREADS_PER_BLOCK = 256; + +// Error codes for kernel error reporting. +constexpr int ERR_BOUNDS = 1; +constexpr int ERR_VARINT = 2; +constexpr int ERR_FIELD_NUMBER = 3; +constexpr int ERR_WIRE_TYPE = 4; +constexpr int ERR_OVERFLOW = 5; +constexpr int ERR_FIELD_SIZE = 6; +constexpr int ERR_SKIP = 7; +constexpr int ERR_FIXED_LEN = 8; +constexpr int ERR_REQUIRED = 9; + +// Maximum supported nesting depth for recursive struct decoding. +constexpr int MAX_NESTED_STRUCT_DECODE_DEPTH = 10; + +// Threshold for using a direct-mapped lookup table for field_number -> field_index. +// Field numbers above this threshold fall back to linear search. +constexpr int FIELD_LOOKUP_TABLE_MAX = 4096; + +/** + * Structure to record field location within a message. + * offset < 0 means field was not found. 
+ */ +struct field_location { + int32_t offset; // Offset of field data within the message (-1 if not found) + int32_t length; // Length of field data in bytes +}; + +/** + * Field descriptor passed to the scanning kernel. + */ +struct field_descriptor { + int field_number; // Protobuf field number + int expected_wire_type; // Expected wire type for this field +}; + +/** + * Information about repeated field occurrences in a row. + */ +struct repeated_field_info { + int32_t count; // Number of occurrences in this row + int32_t total_length; // Total bytes for all occurrences (for varlen fields) +}; + +/** + * Location of a single occurrence of a repeated field. + */ +struct repeated_occurrence { + int32_t row_idx; // Which row this occurrence belongs to + int32_t offset; // Offset within the message + int32_t length; // Length of the field data +}; + +/** + * Device-side descriptor for nested schema fields. + */ +struct device_nested_field_descriptor { + int field_number; + int parent_idx; + int depth; + int wire_type; + int output_type_id; + int encoding; + bool is_repeated; + bool is_required; + bool has_default_value; + + device_nested_field_descriptor() = default; + + explicit device_nested_field_descriptor(spark_rapids_jni::nested_field_descriptor const& src) + : field_number(src.field_number), + parent_idx(src.parent_idx), + depth(src.depth), + wire_type(src.wire_type), + output_type_id(static_cast(src.output_type)), + encoding(src.encoding), + is_repeated(src.is_repeated), + is_required(src.is_required), + has_default_value(src.has_default_value) + { + } +}; + +// ============================================================================ +// Device helper functions +// ============================================================================ + +__device__ inline bool read_varint(uint8_t const* cur, + uint8_t const* end, + uint64_t& out, + int& bytes) +{ + out = 0; + bytes = 0; + int shift = 0; + // Protobuf varint uses 7 bits per byte with MSB as 
continuation flag. + // A 64-bit value requires at most ceil(64/7) = 10 bytes. + while (cur < end && bytes < MAX_VARINT_BYTES) { + uint8_t b = *cur++; + // For the 10th byte (bytes == 9, shift == 63), only the lowest bit is valid + if (bytes == 9 && (b & 0xFE) != 0) { + return false; // Invalid: 10th byte has more than 1 significant bit + } + out |= (static_cast(b & 0x7Fu) << shift); + bytes++; + if ((b & 0x80u) == 0) { return true; } + shift += 7; + } + return false; +} + +__device__ inline void set_error_once(int* error_flag, int error_code) +{ + atomicCAS(error_flag, 0, error_code); +} + +__device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) +{ + switch (wt) { + case WT_VARINT: { + // Need to scan to find the end of varint + int count = 0; + while (cur < end && count < MAX_VARINT_BYTES) { + if ((*cur++ & 0x80u) == 0) { return count + 1; } + count++; + } + return -1; // Invalid varint + } + case WT_64BIT: + // Check if there's enough data for 8 bytes + if (end - cur < 8) return -1; + return 8; + case WT_32BIT: + // Check if there's enough data for 4 bytes + if (end - cur < 4) return -1; + return 4; + case WT_LEN: { + uint64_t len; + int n; + if (!read_varint(cur, end, len, n)) return -1; + if (len > static_cast(end - cur - n) || len > static_cast(INT_MAX - n)) + return -1; + return n + static_cast(len); + } + case WT_SGROUP: { + auto const* start = cur; + // Recursively skip until the matching end-group tag. 
+ while (cur < end) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, end, key, key_bytes)) return -1; + cur += key_bytes; + + int inner_wt = static_cast(key & 0x7); + if (inner_wt == WT_EGROUP) { return static_cast(cur - start); } + + int inner_size = get_wire_type_size(inner_wt, cur, end); + if (inner_size < 0 || cur + inner_size > end) return -1; + cur += inner_size; + } + return -1; + } + case WT_EGROUP: return 0; + default: return -1; + } +} + +__device__ inline bool skip_field(uint8_t const* cur, + uint8_t const* end, + int wt, + uint8_t const*& out_cur) +{ + // End-group is handled by the parent group parser. + if (wt == WT_EGROUP) { + out_cur = cur; + return true; + } + + int size = get_wire_type_size(wt, cur, end); + if (size < 0) return false; + // Ensure we don't skip past the end of the buffer + if (cur + size > end) return false; + out_cur = cur + size; + return true; +} + +/** + * Get the data offset and length for a field at current position. + * Returns true on success, false on error. 
+ */ +__device__ inline bool get_field_data_location( + uint8_t const* cur, uint8_t const* end, int wt, int32_t& data_offset, int32_t& data_length) +{ + if (wt == WT_LEN) { + // For length-delimited, read the length prefix + uint64_t len; + int len_bytes; + if (!read_varint(cur, end, len, len_bytes)) return false; + if (len > static_cast(end - cur - len_bytes) || + len > static_cast(INT_MAX)) { + return false; + } + data_offset = len_bytes; // offset past the length prefix + data_length = static_cast(len); + } else { + // For fixed-size and varint fields + int field_size = get_wire_type_size(wt, cur, end); + if (field_size < 0) return false; + data_offset = 0; + data_length = field_size; + } + return true; +} + +__device__ inline bool check_message_bounds(int32_t start, + int32_t end_pos, + cudf::size_type total_size, + int* error_flag) +{ + if (start < 0 || end_pos < start || end_pos > total_size) { + set_error_once(error_flag, ERR_BOUNDS); + return false; + } + return true; +} + +struct proto_tag { + int field_number; + int wire_type; +}; + +__device__ inline bool decode_tag(uint8_t const*& cur, + uint8_t const* end, + proto_tag& tag, + int* error_flag) +{ + uint64_t key; + int key_bytes; + if (!read_varint(cur, end, key, key_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + + cur += key_bytes; + tag.field_number = static_cast(key >> 3); + tag.wire_type = static_cast(key & 0x7); + if (tag.field_number == 0) { + set_error_once(error_flag, ERR_FIELD_NUMBER); + return false; + } + return true; +} + +/** + * Load a little-endian value from unaligned memory. + * Reads bytes individually to avoid unaligned-access issues on GPU. 
+ */ +template +__device__ inline T load_le(uint8_t const* p); + +template <> +__device__ inline uint32_t load_le(uint8_t const* p) +{ + return static_cast(p[0]) | (static_cast(p[1]) << 8) | + (static_cast(p[2]) << 16) | (static_cast(p[3]) << 24); +} + +template <> +__device__ inline uint64_t load_le(uint8_t const* p) +{ + uint64_t v = 0; +#pragma unroll + for (int i = 0; i < 8; ++i) { + v |= (static_cast(p[i]) << (8 * i)); + } + return v; +} + +// ============================================================================ +// Field number lookup table helpers +// ============================================================================ + +/** + * Build a host-side direct-mapped lookup table: field_number -> field_index. + * Returns an empty vector if the max field number exceeds the threshold. + */ +inline std::vector build_field_lookup_table(field_descriptor const* descs, int num_fields) +{ + int max_fn = 0; + for (int i = 0; i < num_fields; i++) { + max_fn = std::max(max_fn, descs[i].field_number); + } + if (max_fn > FIELD_LOOKUP_TABLE_MAX) return {}; + std::vector table(max_fn + 1, -1); + for (int i = 0; i < num_fields; i++) { + table[descs[i].field_number] = i; + } + return table; +} + +/** + * O(1) lookup of field_number -> field_index using a direct-mapped table. + * Falls back to linear search when the table is empty (field numbers too large). 
+ */ +__device__ inline int lookup_field(int field_number, + int const* lookup_table, + int lookup_table_size, + field_descriptor const* field_descs, + int num_fields) +{ + if (lookup_table != nullptr && field_number > 0 && field_number < lookup_table_size) { + return lookup_table[field_number]; + } + for (int f = 0; f < num_fields; f++) { + if (field_descs[f].field_number == field_number) return f; + } + return -1; +} + +// ============================================================================ +// Pass 2: Extract data kernels +// ============================================================================ + +// ============================================================================ +// Data Extraction Location Providers +// ============================================================================ + +struct TopLevelLocationProvider { + cudf::size_type const* offsets; + cudf::size_type base_offset; + field_location const* locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto loc = locations[thread_idx * num_fields + field_idx]; + if (loc.offset >= 0) { data_offset = offsets[thread_idx] - base_offset + loc.offset; } + return loc; + } +}; + +struct RepeatedLocationProvider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + repeated_occurrence const* occurrences; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto occ = occurrences[thread_idx]; + data_offset = row_offsets[occ.row_idx] - base_offset + occ.offset; + return {occ.offset, occ.length}; + } +}; + +struct NestedLocationProvider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* parent_locations; + field_location const* child_locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto ploc = parent_locations[thread_idx]; + 
auto cloc = child_locations[thread_idx * num_fields + field_idx]; + if (ploc.offset >= 0 && cloc.offset >= 0) { + data_offset = row_offsets[thread_idx] - base_offset + ploc.offset + cloc.offset; + } else { + cloc.offset = -1; + } + return cloc; + } +}; + +struct NestedRepeatedLocationProvider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* parent_locations; + repeated_occurrence const* occurrences; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto occ = occurrences[thread_idx]; + auto ploc = parent_locations[occ.row_idx]; + data_offset = row_offsets[occ.row_idx] - base_offset + ploc.offset + occ.offset; + return {occ.offset, occ.length}; + } +}; + +struct RepeatedMsgChildLocationProvider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* msg_locations; + field_location const* child_locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto mloc = msg_locations[thread_idx]; + auto cloc = child_locations[thread_idx * num_fields + field_idx]; + if (mloc.offset >= 0 && cloc.offset >= 0) { + data_offset = row_offsets[thread_idx] - base_offset + mloc.offset + cloc.offset; + } else { + cloc.offset = -1; + } + return cloc; + } +}; + +template +__global__ void extract_varint_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + OutputType* out, + bool* valid, + int* error_flag, + bool has_default = false, + int64_t default_value = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + if (loc.offset < 0) { + if (has_default) { + out[idx] = static_cast(default_value); + if (valid) valid[idx] = true; + } else { + if (valid) valid[idx] = false; + } + return; + } + + uint8_t const* cur = 
message_data + data_offset; + uint8_t const* cur_end = cur + loc.length; + + uint64_t v; + int n; + if (!read_varint(cur, cur_end, v, n)) { + set_error_once(error_flag, ERR_VARINT); + if (valid) valid[idx] = false; + return; + } + + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } + out[idx] = static_cast(v); + if (valid) valid[idx] = true; +} + +template +__global__ void extract_fixed_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + OutputType* out, + bool* valid, + int* error_flag, + bool has_default = false, + OutputType default_value = OutputType{}) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + if (loc.offset < 0) { + if (has_default) { + out[idx] = default_value; + if (valid) valid[idx] = true; + } else { + if (valid) valid[idx] = false; + } + return; + } + + uint8_t const* cur = message_data + data_offset; + OutputType value; + + if constexpr (WT == WT_32BIT) { + if (loc.length < 4) { + set_error_once(error_flag, ERR_FIXED_LEN); + if (valid) valid[idx] = false; + return; + } + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { + if (loc.length < 8) { + set_error_once(error_flag, ERR_FIXED_LEN); + if (valid) valid[idx] = false; + return; + } + uint64_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } + + out[idx] = value; + if (valid) valid[idx] = true; +} + +template +__global__ void extract_lengths_kernel(LocationProvider loc_provider, + int total_items, + int32_t* out_lengths, + bool has_default = false, + int32_t default_length = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + if (loc.offset >= 0) { + out_lengths[idx] = loc.length; + } else if (has_default) { + out_lengths[idx] = default_length; + } 
else { + out_lengths[idx] = 0; + } +} +template +__global__ void copy_varlen_data_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + cudf::size_type const* output_offsets, + char* output_chars, + int* error_flag, + bool has_default = false, + uint8_t const* default_chars = nullptr, + int default_len = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + auto out_start = output_offsets[idx]; + + if (loc.offset < 0) { + if (has_default && default_len > 0) { + memcpy(output_chars + out_start, default_chars, default_len); + } + return; + } + + uint8_t const* src = message_data + data_offset; + memcpy(output_chars + out_start, src, loc.length); +} + +template +inline std::pair make_null_mask_from_valid( + rmm::device_uvector const& valid, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto begin = thrust::make_counting_iterator(0); + auto end = begin + valid.size(); + auto pred = [ptr = valid.data()] __device__(cudf::size_type i) { + return static_cast(ptr[i]); + }; + return cudf::detail::valid_if(begin, end, pred, stream, mr); +} + + +template +std::unique_ptr extract_and_build_scalar_column(cudf::data_type dt, + int num_rows, + LaunchFn&& launch_extract, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); + launch_extract(out.data(), valid.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + return std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count); +} + +template +// Shared integer extractor for INT32/INT64/UINT32/UINT64 decode paths. 
+inline void extract_integer_into_buffers(uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_rows, + int blocks, + int threads, + bool has_default, + int64_t default_value, + int encoding, + bool enable_zigzag, + T* out_ptr, + bool* valid_ptr, + int* error_ptr, + rmm::cuda_stream_view stream) +{ + if (enable_zigzag && encoding == spark_rapids_jni::ENC_ZIGZAG) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + default_value); + } else if (encoding == spark_rapids_jni::ENC_FIXED) { + if constexpr (sizeof(T) == 4) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + static_cast(default_value)); + } else { + static_assert(sizeof(T) == 8, "extract_integer_into_buffers only supports 32/64-bit"); + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + static_cast(default_value)); + } + } else { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + default_value); + } +} + +template +// Builds a scalar column for integer-like protobuf fields. 
+std::unique_ptr extract_and_build_integer_column(cudf::data_type dt, + uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_rows, + int blocks, + int threads, + rmm::device_uvector& d_error, + bool has_default, + int64_t default_value, + int encoding, + bool enable_zigzag, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + return extract_and_build_scalar_column( + dt, + num_rows, + [&](T* out_ptr, bool* valid_ptr) { + extract_integer_into_buffers(message_data, + loc_provider, + num_rows, + blocks, + threads, + has_default, + default_value, + encoding, + enable_zigzag, + out_ptr, + valid_ptr, + d_error.data(), + stream); + }, + stream, + mr); +} + +struct extract_strided_count { + repeated_field_info const* info; + int field_idx; + int num_fields; + + __device__ int32_t operator()(int row) const { return info[row * num_fields + field_idx].count; } +}; + +/** + * Find all child field indices for a given parent index in the schema. + * This is a commonly used pattern throughout the codebase. 
+ * + * @param schema The schema vector (either nested_field_descriptor or + * device_nested_field_descriptor) + * @param num_fields Number of fields in the schema + * @param parent_idx The parent index to search for + * @return Vector of child field indices + */ +template +std::vector find_child_field_indices(SchemaT const& schema, int num_fields, int parent_idx) +{ + std::vector child_indices; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == parent_idx) { child_indices.push_back(i); } + } + return child_indices; +} + +// Forward declarations needed by make_empty_struct_column_with_schema +std::unique_ptr make_empty_column_safe(cudf::data_type dtype, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr make_empty_list_column(std::unique_ptr element_col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +template +std::unique_ptr make_empty_struct_column_with_schema( + SchemaT const& schema, + std::vector const& schema_output_types, + int parent_idx, + int num_fields, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto child_indices = find_child_field_indices(schema, num_fields, parent_idx); + + std::vector> children; + for (int child_idx : child_indices) { + auto child_type = schema_output_types[child_idx]; + + std::unique_ptr child_col; + if (child_type.id() == cudf::type_id::STRUCT) { + child_col = make_empty_struct_column_with_schema( + schema, schema_output_types, child_idx, num_fields, stream, mr); + } else { + child_col = make_empty_column_safe(child_type, stream, mr); + } + + if (schema[child_idx].is_repeated) { + child_col = make_empty_list_column(std::move(child_col), stream, mr); + } + + children.push_back(std::move(child_col)); + } + + return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer{}, stream, mr); +} + + +// ============================================================================ +// Forward declarations of 
non-template __global__ kernels +// ============================================================================ + +__global__ void scan_all_fields_kernel( + cudf::column_device_view const d_in, field_descriptor const* field_descs, int num_fields, + int const* field_lookup, int field_lookup_size, field_location* locations, int* error_flag); + +__global__ void count_repeated_fields_kernel( + cudf::column_device_view const d_in, device_nested_field_descriptor const* schema, + int num_fields, int depth_level, repeated_field_info* repeated_info, int num_repeated_fields, + int const* repeated_field_indices, field_location* nested_locations, int num_nested_fields, + int const* nested_field_indices, int* error_flag); + +__global__ void scan_repeated_field_occurrences_kernel( + cudf::column_device_view const d_in, device_nested_field_descriptor const* schema, + int schema_idx, int depth_level, int32_t const* output_offsets, + repeated_occurrence* occurrences, int* error_flag); + +__global__ void scan_nested_message_fields_kernel( + uint8_t const* message_data, cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, field_location const* parent_locations, + int num_parent_rows, field_descriptor const* field_descs, int num_fields, + field_location* output_locations, int* error_flag); + +__global__ void scan_repeated_message_children_kernel( + uint8_t const* message_data, int32_t const* msg_row_offsets, field_location const* msg_locs, + int num_occurrences, field_descriptor const* child_descs, int num_child_fields, + field_location* child_locs, int* error_flag); + +__global__ void count_repeated_in_nested_kernel( + uint8_t const* message_data, cudf::size_type const* row_offsets, cudf::size_type base_offset, + field_location const* parent_locs, int num_rows, device_nested_field_descriptor const* schema, + int num_fields, repeated_field_info* repeated_info, int num_repeated, + int const* repeated_indices, int* error_flag); + +__global__ void 
scan_repeated_in_nested_kernel( + uint8_t const* message_data, cudf::size_type const* row_offsets, cudf::size_type base_offset, + field_location const* parent_locs, int num_rows, device_nested_field_descriptor const* schema, + int num_fields, int32_t const* occ_prefix_sums, int num_repeated, + int const* repeated_indices, repeated_occurrence* occurrences, int* error_flag); + +__global__ void compute_nested_struct_locations_kernel( + field_location const* child_locs, field_location const* msg_locs, + int32_t const* msg_row_offsets, int child_idx, int num_child_fields, + field_location* nested_locs, int32_t* nested_row_offsets, int total_count); + +__global__ void compute_grandchild_parent_locations_kernel( + field_location const* parent_locs, field_location const* child_locs, + int child_idx, int num_child_fields, field_location* gc_parent_abs, int num_rows); + +__global__ void compute_virtual_parents_for_nested_repeated_kernel( + repeated_occurrence const* occurrences, cudf::size_type const* row_list_offsets, + field_location const* parent_locations, cudf::size_type* virtual_row_offsets, + field_location* virtual_parent_locs, int total_count); + +__global__ void compute_msg_locations_from_occurrences_kernel( + repeated_occurrence const* occurrences, cudf::size_type const* list_offsets, + cudf::size_type base_offset, field_location* msg_locs, int32_t* msg_row_offsets, + int total_count); + +__global__ void extract_strided_locations_kernel( + field_location const* nested_locations, int field_idx, int num_fields, + field_location* parent_locs, int num_rows); + +__global__ void check_required_fields_kernel( + field_location const* locations, uint8_t const* is_required, int num_fields, + int num_rows, int* error_flag); + +__global__ void validate_enum_values_kernel( + int32_t const* values, bool* valid, bool* row_has_invalid_enum, + int32_t const* valid_enum_values, int num_valid_values, int num_rows); + +__global__ void compute_enum_string_lengths_kernel( + int32_t 
const* values, bool const* valid, int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, int num_valid_values, int32_t* lengths, int num_rows); + +__global__ void copy_enum_string_chars_kernel( + int32_t const* values, bool const* valid, int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, uint8_t const* enum_name_chars, int num_valid_values, + int32_t const* output_offsets, char* out_chars, int num_rows); + +// ============================================================================ +// Forward declarations of builder/utility functions +// ============================================================================ + +std::unique_ptr make_null_column(cudf::data_type dtype, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr make_null_list_column_with_child( + std::unique_ptr child_col, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr build_enum_string_column( + rmm::device_uvector& enum_values, + rmm::device_uvector& valid, + std::vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_has_invalid_enum, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +// Complex builder forward declarations +std::unique_ptr build_repeated_string_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + bool is_bytes, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr build_nested_struct_column( + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_parent_locs, + 
std::vector const& child_field_indices, + std::vector const& schema, + int num_fields, + std::vector const& schema_output_types, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_error, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int depth); + +std::unique_ptr build_repeated_child_list_column( + uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_parent_rows, + int child_schema_idx, + std::vector const& schema, + int num_fields, + std::vector const& schema_output_types, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int depth); + +std::unique_ptr build_repeated_struct_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + std::vector const& h_device_schema, + std::vector const& child_field_indices, + std::vector const& schema_output_types, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector const& schema, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& 
d_row_has_invalid_enum, + rmm::device_uvector& d_error_top, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +template +inline std::unique_ptr extract_and_build_string_or_bytes_column( + bool as_bytes, + uint8_t const* message_data, + int num_rows, + LengthProvider const& length_provider, + CopyProvider const& copy_provider, + ValidityFn validity_fn, + bool has_default, + std::vector const& default_bytes, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + int32_t def_len = has_default ? static_cast(default_bytes.size()) : 0; + rmm::device_uvector d_default(def_len, stream, mr); + if (has_default && def_len > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync( + d_default.data(), default_bytes.data(), def_len, cudaMemcpyHostToDevice, stream.value())); + } + + rmm::device_uvector lengths(num_rows, stream, mr); + auto const threads = THREADS_PER_BLOCK; + auto const blocks = (num_rows + threads - 1) / threads; + extract_lengths_kernel<<>>( + length_provider, num_rows, lengths.data(), has_default, def_len); + + auto [offsets_col, total_size] = cudf::strings::detail::make_offsets_child_column( + lengths.begin(), lengths.end(), stream, mr); + + rmm::device_uvector chars(total_size, stream, mr); + if (total_size > 0) { + copy_varlen_data_kernel + <<>>(message_data, + copy_provider, + num_rows, + offsets_col->view().data(), + chars.data(), + d_error.data(), + has_default, + d_default.data(), + def_len); + } + + rmm::device_uvector valid((num_rows > 0 ? 
num_rows : 1), stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.data(), + validity_fn); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + if (as_bytes) { + auto bytes_child = + std::make_unique(cudf::data_type{cudf::type_id::UINT8}, + total_size, + rmm::device_buffer(chars.data(), total_size, stream, mr), + rmm::device_buffer{}, + 0); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(bytes_child), + null_count, + std::move(mask), + stream, + mr); + } + + return cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); +} + +template +inline std::unique_ptr extract_typed_column( + cudf::data_type dt, + int encoding, + uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_items, + int blocks, + int threads_per_block, + bool has_default, + int64_t default_int, + double default_float, + bool default_bool, + std::vector const& default_string, + int schema_idx, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + switch (dt.id()) { + case cudf::type_id::BOOL8: { + int64_t def_val = has_default ? (default_bool ? 1 : 0) : 0; + return extract_and_build_scalar_column( + dt, + num_items, + [&](uint8_t* out_ptr, bool* valid_ptr) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_val); + }, + stream, + mr); + } + case cudf::type_id::INT32: { + rmm::device_uvector out(num_items, stream, mr); + rmm::device_uvector valid((num_items > 0 ? 
num_items : 1), stream, mr); + extract_integer_into_buffers(message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + has_default, + default_int, + encoding, + true, + out.data(), + valid.data(), + d_error.data(), + stream); + if (schema_idx < static_cast(enum_valid_values.size())) { + auto const& valid_enums = enum_valid_values[schema_idx]; + if (!valid_enums.empty()) { + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + validate_enum_values_kernel<<>>( + out.data(), + valid.data(), + d_row_has_invalid_enum.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + num_items); + } + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + return std::make_unique( + dt, num_items, out.release(), std::move(mask), null_count); + } + case cudf::type_id::UINT32: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + false, + stream, + mr); + case cudf::type_id::INT64: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + true, + stream, + mr); + case cudf::type_id::UINT64: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + false, + stream, + mr); + case cudf::type_id::FLOAT32: { + float def_float_val = has_default ? 
static_cast(default_float) : 0.0f; + return extract_and_build_scalar_column( + dt, + num_items, + [&](float* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_float_val); + }, + stream, + mr); + } + case cudf::type_id::FLOAT64: { + double def_double = has_default ? default_float : 0.0; + return extract_and_build_scalar_column( + dt, + num_items, + [&](double* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_double); + }, + stream, + mr); + } + default: return make_null_column(dt, num_items, stream, mr); + } +} + +template +inline std::unique_ptr build_repeated_scalar_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const input_null_count = binary_input.null_count(); + + if (total_count == 0) { + // All rows have count=0, but we still need to check input nulls + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + auto elem_type = field_desc.output_type_id == static_cast(cudf::type_id::LIST) + ? 
cudf::type_id::UINT8 + : static_cast(field_desc.output_type_id); + auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); + + if (input_null_count > 0) { + // Copy input null mask - only input nulls produce output nulls + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask), + stream, + mr); + } else { + // No input nulls, all rows get empty arrays [] + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + 0, + rmm::device_buffer{}, + stream, + mr); + } + } + + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); + + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, + &total_count, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector values(total_count, stream, mr); + rmm::device_uvector d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + + auto const threads = THREADS_PER_BLOCK; + auto const blocks = (total_count + threads - 1) / threads; + + int encoding = field_desc.encoding; + bool zigzag = (encoding == spark_rapids_jni::ENC_ZIGZAG); + + // For float/double types, always use fixed kernel (they use wire type 32BIT/64BIT) + // For integer types, use fixed kernel only if encoding is ENC_FIXED + constexpr bool is_floating_point = std::is_same_v || std::is_same_v; + bool use_fixed_kernel = is_floating_point || (encoding == spark_rapids_jni::ENC_FIXED); + + RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; + if (use_fixed_kernel) { + if constexpr (sizeof(T) == 4) { + extract_fixed_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } else { + 
extract_fixed_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } + } else if (zigzag) { + extract_varint_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } else { + extract_varint_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } + + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + auto child_col = std::make_unique( + cudf::data_type{static_cast(field_desc.output_type_id)}, + total_count, + values.release(), + rmm::device_buffer{}, + 0); + + // Only rows where INPUT is null should produce null output + // Rows with valid input but count=0 should produce empty array [] + if (input_null_count > 0) { + // Copy input null mask - only input nulls produce output nulls + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask), + stream, + mr); + } + + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); +} + +} // namespace spark_rapids_jni::protobuf_detail diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu new file mode 100644 index 0000000000..a12a4c6abe --- /dev/null +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -0,0 +1,1106 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "protobuf_common.cuh" + +namespace spark_rapids_jni::protobuf_detail { + +// ============================================================================ +// Pass 1: Scan all fields kernel - records (offset, length) for each field +// ============================================================================ + +/** + * Fused scanning kernel: scans each message once and records the location + * of all requested fields. + * + * For "last one wins" semantics (protobuf standard for repeated scalars), + * we continue scanning even after finding a field. + */ +__global__ void scan_all_fields_kernel( + cudf::column_device_view const d_in, + field_descriptor const* field_descs, // [num_fields] + int num_fields, + int const* field_lookup, // direct-mapped lookup table (nullable) + int field_lookup_size, // size of lookup table (0 if null) + field_location* locations, // [num_rows * num_fields] row-major + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + cudf::detail::lists_column_device_view in{d_in}; + if (row >= in.size()) return; + + for (int f = 0; f < num_fields; f++) { + locations[row * num_fields + f] = {-1, 0}; + } + + if (in.nullable() && in.is_null(row)) { return; } + + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); + auto const* bytes = reinterpret_cast(child.data()); + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; + + if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } + + uint8_t const* cur 
= bytes + start; + uint8_t const* msg_end = bytes + end; + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; + + int f = lookup_field(fn, field_lookup, field_lookup_size, field_descs, num_fields); + if (f >= 0) { + if (wt != field_descs[f].expected_wire_type) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return; + } + + // Record the location (relative to message start) + int data_offset = static_cast(cur - bytes - start); + + if (wt == WT_LEN) { + // For length-delimited, record offset after length prefix and the data length + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + if (len > static_cast(msg_end - cur - len_bytes) || + len > static_cast(INT_MAX)) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + // Record offset pointing to the actual data (after length prefix) + locations[row * num_fields + f] = {data_offset + len_bytes, static_cast(len)}; + } else { + // For fixed-size and varint fields, record offset and compute length + int field_size = get_wire_type_size(wt, cur, msg_end); + if (field_size < 0) { + set_error_once(error_flag, ERR_FIELD_SIZE); + return; + } + locations[row * num_fields + f] = {data_offset, field_size}; + } + } + + // Skip to next field + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } +} + +// ============================================================================ +// Pass 1b: Count repeated fields kernel +// ============================================================================ + +/** + * Count occurrences of repeated fields in each row. + * Also records locations of nested message fields for hierarchical processing. + * + * @note Time complexity: O(message_length * (num_repeated_fields + num_nested_fields)) per row. 
+ */ +__global__ void count_repeated_fields_kernel( + cudf::column_device_view const d_in, + device_nested_field_descriptor const* schema, + int num_fields, + int depth_level, // Which depth level we're processing + repeated_field_info* repeated_info, // [num_rows * num_repeated_fields_at_this_depth] + int num_repeated_fields, // Number of repeated fields at this depth + int const* repeated_field_indices, // Indices into schema for repeated fields at this depth + field_location* + nested_locations, // Locations of nested messages for next depth [num_rows * num_nested] + int num_nested_fields, // Number of nested message fields at this depth + int const* nested_field_indices, // Indices into schema for nested message fields + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + cudf::detail::lists_column_device_view in{d_in}; + if (row >= in.size()) return; + + // Initialize repeated counts to 0 + for (int f = 0; f < num_repeated_fields; f++) { + repeated_info[row * num_repeated_fields + f] = {0, 0}; + } + + // Initialize nested locations to not found + for (int f = 0; f < num_nested_fields; f++) { + nested_locations[row * num_nested_fields + f] = {-1, 0}; + } + + if (in.nullable() && in.is_null(row)) { return; } + + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); + auto const* bytes = reinterpret_cast(child.data()); + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; + if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } + + uint8_t const* cur = bytes + start; + uint8_t const* msg_end = bytes + end; + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; + + // Check repeated fields at this depth + for (int i = 0; i < num_repeated_fields; i++) { + int schema_idx = repeated_field_indices[i]; + if (schema[schema_idx].field_number == fn && 
schema[schema_idx].depth == depth_level) { + int expected_wt = schema[schema_idx].wire_type; + + // Handle both packed and unpacked encoding for repeated fields + // Packed encoding uses wire type LEN (2) even for scalar types + bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); + + if (!is_packed && wt != expected_wt) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return; + } + + if (is_packed) { + // Packed encoding: read length, then count elements inside + uint64_t packed_len; + int len_bytes; + if (!read_varint(cur, msg_end, packed_len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + + // Count elements based on type + uint8_t const* packed_start = cur + len_bytes; + uint8_t const* packed_end = packed_start + packed_len; + if (packed_end > msg_end) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + + int count = 0; + if (expected_wt == WT_VARINT) { + // Count varints in the packed data + uint8_t const* p = packed_start; + while (p < packed_end) { + uint64_t dummy; + int vbytes; + if (!read_varint(p, packed_end, dummy, vbytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + p += vbytes; + count++; + } + } else if (expected_wt == WT_32BIT) { + if ((packed_len % 4) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } + count = static_cast(packed_len / 4); + } else if (expected_wt == WT_64BIT) { + if ((packed_len % 8) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } + count = static_cast(packed_len / 8); + } + + repeated_info[row * num_repeated_fields + i].count += count; + repeated_info[row * num_repeated_fields + i].total_length += + static_cast(packed_len); + } else { + // Non-packed encoding: single element + int32_t data_offset, data_length; + if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { + set_error_once(error_flag, ERR_FIELD_SIZE); + return; + } + + repeated_info[row * num_repeated_fields + i].count++; + repeated_info[row * 
num_repeated_fields + i].total_length += data_length; + } + } + } + + // Check nested message fields at this depth (last one wins for non-repeated) + for (int i = 0; i < num_nested_fields; i++) { + int schema_idx = nested_field_indices[i]; + if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { + if (wt != WT_LEN) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return; + } + + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + + int32_t msg_offset = static_cast(cur - bytes - start) + len_bytes; + nested_locations[row * num_nested_fields + i] = {msg_offset, static_cast(len)}; + } + } + + // Skip to next field + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } +} + +/** + * Scan and record all occurrences of repeated fields. + * Called after count_repeated_fields_kernel to fill in actual locations. + * + * @note Time complexity: O(message_length * num_repeated_fields) per row. 
+ */ +__global__ void scan_repeated_field_occurrences_kernel( + cudf::column_device_view const d_in, + device_nested_field_descriptor const* schema, + int schema_idx, // Which field in schema we're scanning + int depth_level, + int32_t const* output_offsets, // Pre-computed offsets from prefix sum [num_rows + 1] + repeated_occurrence* occurrences, // Output: all occurrences [total_count] + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + cudf::detail::lists_column_device_view in{d_in}; + if (row >= in.size()) return; + + if (in.nullable() && in.is_null(row)) { return; } + + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); + auto const* bytes = reinterpret_cast(child.data()); + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; + if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } + + uint8_t const* cur = bytes + start; + uint8_t const* msg_end = bytes + end; + + int target_fn = schema[schema_idx].field_number; + int target_wt = schema[schema_idx].wire_type; + int write_idx = output_offsets[row]; + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; + + if (fn == target_fn) { + // Check for packed encoding: wire type LEN but expected non-LEN + bool is_packed = (wt == WT_LEN && target_wt != WT_LEN); + + if (is_packed) { + // Packed encoding: multiple elements in a length-delimited blob + uint64_t packed_len; + int len_bytes; + if (!read_varint(cur, msg_end, packed_len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + + uint8_t const* packed_start = cur + len_bytes; + uint8_t const* packed_end = packed_start + packed_len; + if (packed_end > msg_end) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + + // Record each element in the packed blob + if (target_wt == WT_VARINT) { + // Varints: parse each 
one + uint8_t const* p = packed_start; + while (p < packed_end) { + int32_t elem_offset = static_cast(p - bytes - start); + uint64_t dummy; + int vbytes; + if (!read_varint(p, packed_end, dummy, vbytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + occurrences[write_idx] = {static_cast(row), elem_offset, vbytes}; + write_idx++; + p += vbytes; + } + } else if (target_wt == WT_32BIT) { + // Fixed 32-bit: each element is 4 bytes + uint8_t const* p = packed_start; + while (p + 4 <= packed_end) { + int32_t elem_offset = static_cast(p - bytes - start); + occurrences[write_idx] = {static_cast(row), elem_offset, 4}; + write_idx++; + p += 4; + } + } else if (target_wt == WT_64BIT) { + // Fixed 64-bit: each element is 8 bytes + uint8_t const* p = packed_start; + while (p + 8 <= packed_end) { + int32_t elem_offset = static_cast(p - bytes - start); + occurrences[write_idx] = {static_cast(row), elem_offset, 8}; + write_idx++; + p += 8; + } + } + } else if (wt == target_wt) { + // Non-packed encoding: single element + int32_t data_offset, data_length; + if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { + set_error_once(error_flag, ERR_FIELD_SIZE); + return; + } + + int32_t abs_offset = static_cast(cur - bytes - start) + data_offset; + occurrences[write_idx] = {static_cast(row), abs_offset, data_length}; + write_idx++; + } + } + + // Skip to next field + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } +} + +// ============================================================================ +// Nested message scanning kernels +// ============================================================================ + +/** + * Scan nested message fields. + * Each row represents a nested message at a specific parent location. + * This kernel finds fields within the nested message bytes. 
+ */ +__global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + int num_parent_rows, + field_descriptor const* field_descs, + int num_fields, + field_location* output_locations, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_parent_rows) return; + + for (int f = 0; f < num_fields; f++) { + output_locations[row * num_fields + f] = {-1, 0}; + } + + auto const& parent_loc = parent_locations[row]; + if (parent_loc.offset < 0) { return; } + + auto parent_row_start = parent_row_offsets[row] - parent_base_offset; + uint8_t const* nested_start = message_data + parent_row_start + parent_loc.offset; + uint8_t const* nested_end = nested_start + parent_loc.length; + + uint8_t const* cur = nested_start; + + while (cur < nested_end) { + proto_tag tag; + if (!decode_tag(cur, nested_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; + + for (int f = 0; f < num_fields; f++) { + if (field_descs[f].field_number == fn) { + if (wt != field_descs[f].expected_wire_type) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return; + } + + int data_offset = static_cast(cur - nested_start); + + if (wt == WT_LEN) { + uint64_t len; + int len_bytes; + if (!read_varint(cur, nested_end, len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + if (len > static_cast(nested_end - cur - len_bytes) || + len > static_cast(INT_MAX)) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + output_locations[row * num_fields + f] = {data_offset + len_bytes, + static_cast(len)}; + } else { + int field_size = get_wire_type_size(wt, cur, nested_end); + if (field_size < 0) { + set_error_once(error_flag, ERR_FIELD_SIZE); + return; + } + output_locations[row * num_fields + f] = {data_offset, field_size}; + } + } + } + + uint8_t const* next; + if 
(!skip_field(cur, nested_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } +} + +/** + * Scan for child fields within repeated message occurrences. + * Each occurrence is a protobuf message, and we need to find child field locations within it. + */ +__global__ void scan_repeated_message_children_kernel( + uint8_t const* message_data, + int32_t const* msg_row_offsets, // Row offset for each occurrence + field_location const* + msg_locs, // Location of each message occurrence (offset within row, length) + int num_occurrences, + field_descriptor const* child_descs, + int num_child_fields, + field_location* child_locs, // Output: [num_occurrences * num_child_fields] + int* error_flag) +{ + auto occ_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (occ_idx >= num_occurrences) return; + + // Initialize child locations to not found + for (int f = 0; f < num_child_fields; f++) { + child_locs[occ_idx * num_child_fields + f] = {-1, 0}; + } + + auto const& msg_loc = msg_locs[occ_idx]; + if (msg_loc.offset < 0) return; + + // Calculate absolute position of this message in the data + int32_t row_offset = msg_row_offsets[occ_idx]; + uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; + uint8_t const* msg_end = msg_start + msg_loc.length; + + uint8_t const* cur = msg_start; + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; + + // Check against child field descriptors + for (int f = 0; f < num_child_fields; f++) { + if (child_descs[f].field_number == fn) { + bool is_packed = (wt == WT_LEN && child_descs[f].expected_wire_type != WT_LEN); + if (!is_packed && wt != child_descs[f].expected_wire_type) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return; + } + + int data_offset = static_cast(cur - msg_start); + + if (wt == WT_LEN) { + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, 
len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + // Store offset (after length prefix) and length + child_locs[occ_idx * num_child_fields + f] = {data_offset + len_bytes, + static_cast(len)}; + } else { + // For varint/fixed types, store offset and estimated length + int32_t data_length = 0; + if (wt == WT_VARINT) { + uint64_t dummy; + int vbytes; + if (read_varint(cur, msg_end, dummy, vbytes)) { data_length = vbytes; } + } else if (wt == WT_32BIT) { + data_length = 4; + } else if (wt == WT_64BIT) { + data_length = 8; + } + child_locs[occ_idx * num_child_fields + f] = {data_offset, data_length}; + } + // Don't break - last occurrence wins (protobuf semantics) + } + } + + // Skip to next field + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } +} + +/** + * Count repeated field occurrences within nested messages. + * Similar to count_repeated_fields_kernel but operates on nested message locations. 
+ + */ +__global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + repeated_field_info* repeated_info, + int num_repeated, + int const* repeated_indices, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + // Initialize counts + for (int ri = 0; ri < num_repeated; ri++) { + repeated_info[row * num_repeated + ri] = {0, 0}; + } + + auto const& parent_loc = parent_locs[row]; + if (parent_loc.offset < 0) return; + + cudf::size_type row_off; + row_off = row_offsets[row] - base_offset; + + uint8_t const* msg_start = message_data + row_off + parent_loc.offset; + uint8_t const* msg_end = msg_start + parent_loc.length; + uint8_t const* cur = msg_start; + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; + + // Check if this is one of our repeated fields + for (int ri = 0; ri < num_repeated; ri++) { + int schema_idx = repeated_indices[ri]; + if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { + int expected_wt = schema[schema_idx].wire_type; + bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); + + if (!is_packed && wt != expected_wt) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return; + } + + if (is_packed) { + uint64_t packed_len; + int len_bytes; + if (!read_varint(cur, msg_end, packed_len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + uint8_t const* packed_start = cur + len_bytes; + uint8_t const* packed_end = packed_start + packed_len; + if (packed_end > msg_end) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + + int count = 0; + if (expected_wt == WT_VARINT) { + uint8_t const* p = packed_start; + while (p < 
packed_end) { + uint64_t dummy; + int vbytes; + if (!read_varint(p, packed_end, dummy, vbytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + p += vbytes; + count++; + } + } else if (expected_wt == WT_32BIT) { + if ((packed_len % 4) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } + count = static_cast(packed_len / 4); + } else if (expected_wt == WT_64BIT) { + if ((packed_len % 8) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } + count = static_cast(packed_len / 8); + } + repeated_info[row * num_repeated + ri].count += count; + repeated_info[row * num_repeated + ri].total_length += static_cast(packed_len); + } else { + int32_t data_offset, data_len; + if (!get_field_data_location(cur, msg_end, wt, data_offset, data_len)) { + set_error_once(error_flag, ERR_FIELD_SIZE); + return; + } + repeated_info[row * num_repeated + ri].count++; + repeated_info[row * num_repeated + ri].total_length += data_len; + } + } + } + + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } +} + +/** + * Scan for repeated field occurrences within nested messages. + */ +__global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + int32_t const* occ_prefix_sums, + int num_repeated, + int const* repeated_indices, + repeated_occurrence* occurrences, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + auto const& parent_loc = parent_locs[row]; + if (parent_loc.offset < 0) return; + + // Prefix sum gives the write start offset for this row. 
+ int occ_offset = occ_prefix_sums[row]; + + cudf::size_type row_off = row_offsets[row] - base_offset; + + uint8_t const* msg_start = message_data + row_off + parent_loc.offset; + uint8_t const* msg_end = msg_start + parent_loc.length; + uint8_t const* cur = msg_start; + + int occ_idx = 0; + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; + + // Check if this is one of our repeated fields. + for (int ri = 0; ri < num_repeated; ri++) { + int schema_idx = repeated_indices[ri]; + if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { + int expected_wt = schema[schema_idx].wire_type; + bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); + + if (!is_packed && wt != expected_wt) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return; + } + + if (is_packed) { + uint64_t packed_len; + int len_bytes; + if (!read_varint(cur, msg_end, packed_len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + uint8_t const* packed_start = cur + len_bytes; + uint8_t const* packed_end = packed_start + packed_len; + if (packed_end > msg_end) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + + if (expected_wt == WT_VARINT) { + uint8_t const* p = packed_start; + while (p < packed_end) { + int32_t elem_offset = static_cast(p - msg_start); + uint64_t dummy; + int vbytes; + if (!read_varint(p, packed_end, dummy, vbytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + occurrences[occ_offset + occ_idx] = {row, elem_offset, vbytes}; + occ_idx++; + p += vbytes; + } + } else if (expected_wt == WT_32BIT) { + if ((packed_len % 4) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } + for (uint64_t i = 0; i < packed_len; i += 4) { + occurrences[occ_offset + occ_idx] = { + row, static_cast(packed_start - msg_start + i), 4}; + occ_idx++; + } + } else if (expected_wt == WT_64BIT) { + if ((packed_len % 8) != 
0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } + for (uint64_t i = 0; i < packed_len; i += 8) { + occurrences[occ_offset + occ_idx] = { + row, static_cast(packed_start - msg_start + i), 8}; + occ_idx++; + } + } + } else { + int32_t data_offset = static_cast(cur - msg_start); + int32_t data_len = 0; + if (wt == WT_LEN) { + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + data_offset += len_bytes; + data_len = static_cast(len); + } else if (wt == WT_VARINT) { + uint64_t dummy; + int vbytes; + if (read_varint(cur, msg_end, dummy, vbytes)) { data_len = vbytes; } + } else if (wt == WT_32BIT) { + data_len = 4; + } else if (wt == WT_64BIT) { + data_len = 8; + } + + occurrences[occ_offset + occ_idx] = {row, data_offset, data_len}; + occ_idx++; + } + } + } + + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } +} + +/** + * Kernel to compute nested struct locations from child field locations. + * Replaces host-side loop that was copying data D->H, processing, then H->D. + * This is a critical performance optimization. 
+ */ +__global__ void compute_nested_struct_locations_kernel( + field_location const* child_locs, // Child field locations from parent scan + field_location const* msg_locs, // Parent message locations + int32_t const* msg_row_offsets, // Parent message row offsets + int child_idx, // Which child field is the nested struct + int num_child_fields, // Total number of child fields per occurrence + field_location* nested_locs, // Output: nested struct locations + int32_t* nested_row_offsets, // Output: nested struct row offsets + int total_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_count) return; + + // Get the nested struct location from child_locs + nested_locs[idx] = child_locs[idx * num_child_fields + child_idx]; + // Compute absolute row offset = msg_row_offset + msg_offset + nested_row_offsets[idx] = msg_row_offsets[idx] + msg_locs[idx].offset; +} + +/** + * Kernel to compute absolute grandchild parent locations from parent and child locations. + * Computes: gc_parent_abs[i] = parent[i].offset + child[i * ncf + ci].offset + * This replaces host-side loop with D->H->D copy pattern. 
+ */ +__global__ void compute_grandchild_parent_locations_kernel( + field_location const* parent_locs, // Parent locations (row count) + field_location const* child_locs, // Child locations (row * num_child_fields) + int child_idx, // Which child field + int num_child_fields, // Total child fields per row + field_location* gc_parent_abs, // Output: absolute grandchild parent locations + int num_rows) +{ + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= num_rows) return; + + auto const& parent_loc = parent_locs[row]; + auto const& child_loc = child_locs[row * num_child_fields + child_idx]; + + if (parent_loc.offset >= 0 && child_loc.offset >= 0) { + // Absolute offset = parent offset + child's relative offset + gc_parent_abs[row].offset = parent_loc.offset + child_loc.offset; + gc_parent_abs[row].length = child_loc.length; + } else { + gc_parent_abs[row] = {-1, 0}; + } +} + +/** + * Compute virtual parent row offsets and locations for repeated message occurrences + * inside nested messages. Each occurrence becomes a virtual "row" so that + * build_nested_struct_column can recursively process the children. + */ +__global__ void compute_virtual_parents_for_nested_repeated_kernel( + repeated_occurrence const* occurrences, + cudf::size_type const* row_list_offsets, // original binary input list offsets + field_location const* parent_locations, // parent nested message locations + cudf::size_type* virtual_row_offsets, // output: [total_count] + field_location* virtual_parent_locs, // output: [total_count] + int total_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_count) return; + + auto const& occ = occurrences[idx]; + auto const& ploc = parent_locations[occ.row_idx]; + + virtual_row_offsets[idx] = row_list_offsets[occ.row_idx]; + + // Keep zero-length embedded messages as "present but empty". 
+ // Protobuf allows an embedded message with length=0, which maps to a non-null + // struct with all-null children (not a null struct). + if (ploc.offset >= 0) { + virtual_parent_locs[idx] = {ploc.offset + occ.offset, occ.length}; + } else { + virtual_parent_locs[idx] = {-1, 0}; + } +} + +/** + * Kernel to compute message locations and row offsets from repeated occurrences. + * Replaces host-side loop that processed occurrences. + */ +__global__ void compute_msg_locations_from_occurrences_kernel( + repeated_occurrence const* occurrences, // Repeated field occurrences + cudf::size_type const* list_offsets, // List offsets for rows + cudf::size_type base_offset, // Base offset to subtract + field_location* msg_locs, // Output: message locations + int32_t* msg_row_offsets, // Output: message row offsets + int total_count) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_count) return; + + auto const& occ = occurrences[idx]; + msg_row_offsets[idx] = static_cast(list_offsets[occ.row_idx] - base_offset); + msg_locs[idx] = {occ.offset, occ.length}; +} + +/** + * Extract a single field's locations from a 2D strided array on the GPU. + * Replaces a D2H + CPU loop + H2D pattern for nested message location extraction. + */ +__global__ void extract_strided_locations_kernel(field_location const* nested_locations, + int field_idx, + int num_fields, + field_location* parent_locs, + int num_rows) +{ + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= num_rows) return; + parent_locs[row] = nested_locations[row * num_fields + field_idx]; +} + +// ============================================================================ +// Kernel to check required fields after scan pass +// ============================================================================ + +/** + * Check if any required fields are missing (offset < 0) and set error flag. + * This is called after the scan pass to validate required field constraints. 
+ */ +__global__ void check_required_fields_kernel( + field_location const* locations, // [num_rows * num_fields] + uint8_t const* is_required, // [num_fields] (1 = required, 0 = optional) + int num_fields, + int num_rows, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + for (int f = 0; f < num_fields; f++) { + if (is_required[f] != 0 && locations[row * num_fields + f].offset < 0) { + // Required field is missing - set error flag + set_error_once(error_flag, ERR_REQUIRED); + return; // No need to check other fields for this row + } + } +} + +/** + * Validate enum values against a set of valid values. + * If a value is not in the valid set: + * 1. Mark the field as invalid (valid[row] = false) + * 2. Mark the row as having an invalid enum (row_has_invalid_enum[row] = true) + * + * This matches Spark CPU PERMISSIVE mode behavior: when an unknown enum value is + * encountered, the entire struct row is set to null (not just the enum field). + * + * The valid_values array must be sorted for binary search. + * + * @note Time complexity: O(log(num_valid_values)) per row. 
+ + */ +__global__ void validate_enum_values_kernel( + int32_t const* values, // [num_rows] extracted enum values + bool* valid, // [num_rows] field validity flags (will be modified) + bool* row_has_invalid_enum, // [num_rows] row-level invalid enum flag (will be set to true) + int32_t const* valid_enum_values, // sorted array of valid enum values + int num_valid_values, // size of valid_enum_values + int num_rows) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + // Skip if already invalid (field was missing) - missing field is not an enum error + if (!valid[row]) return; + + int32_t val = values[row]; + + // Binary search for the value in valid_enum_values + int left = 0; + int right = num_valid_values - 1; + bool found = false; + + while (left <= right) { + int mid = left + (right - left) / 2; + if (valid_enum_values[mid] == val) { + found = true; + break; + } else if (valid_enum_values[mid] < val) { + left = mid + 1; + } else { + right = mid - 1; + } + } + + // If not found, mark as invalid + if (!found) { + valid[row] = false; + // Also mark the row as having an invalid enum - this will null the entire struct row + row_has_invalid_enum[row] = true; + } +} + +/** + * Compute output UTF-8 length for enum-as-string rows. + * Invalid/missing values produce length 0 (null row/field semantics handled by valid[] and + * row_has_invalid_enum). 
+ + */ +__global__ void compute_enum_string_lengths_kernel( + int32_t const* values, // [num_rows] enum numeric values + bool const* valid, // [num_rows] field validity + int32_t const* valid_enum_values, // sorted enum numeric values + int32_t const* enum_name_offsets, // [num_valid_values + 1] + int num_valid_values, + int32_t* lengths, // [num_rows] + int num_rows) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + if (!valid[row]) { + lengths[row] = 0; + return; + } + + int32_t val = values[row]; + int left = 0; + int right = num_valid_values - 1; + while (left <= right) { + int mid = left + (right - left) / 2; + int32_t mid_val = valid_enum_values[mid]; + if (mid_val == val) { + lengths[row] = enum_name_offsets[mid + 1] - enum_name_offsets[mid]; + return; + } else if (mid_val < val) { + left = mid + 1; + } else { + right = mid - 1; + } + } + + // Should not happen when validate_enum_values_kernel has already run, but keep safe. + lengths[row] = 0; +} + +/** + * Copy enum-as-string UTF-8 bytes into output chars buffer using precomputed row offsets. 
+ */ +__global__ void copy_enum_string_chars_kernel( + int32_t const* values, // [num_rows] enum numeric values + bool const* valid, // [num_rows] field validity + int32_t const* valid_enum_values, // sorted enum numeric values + int32_t const* enum_name_offsets, // [num_valid_values + 1] + uint8_t const* enum_name_chars, // concatenated enum UTF-8 names + int num_valid_values, + int32_t const* output_offsets, // [num_rows + 1] + char* out_chars, // [total_chars] + int num_rows) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + if (!valid[row]) return; + + int32_t val = values[row]; + int left = 0; + int right = num_valid_values - 1; + while (left <= right) { + int mid = left + (right - left) / 2; + int32_t mid_val = valid_enum_values[mid]; + if (mid_val == val) { + int32_t src_begin = enum_name_offsets[mid]; + int32_t src_end = enum_name_offsets[mid + 1]; + int32_t dst_begin = output_offsets[row]; + for (int32_t i = 0; i < (src_end - src_begin); ++i) { + out_chars[dst_begin + i] = static_cast(enum_name_chars[src_begin + i]); + } + return; + } else if (mid_val < val) { + left = mid + 1; + } else { + right = mid - 1; + } + } +} + +} // namespace spark_rapids_jni::protobuf_detail From 89e6e8cde6470204237bda0ba88c2d7188ea1b69 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 2 Mar 2026 17:41:35 +0800 Subject: [PATCH 035/107] style Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 13 +- src/main/cpp/src/protobuf_builders.cu | 61 ++++---- src/main/cpp/src/protobuf_common.cuh | 211 +++++++++++++++++--------- 3 files changed, 169 insertions(+), 116 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index d738ac863f..48c922acbe 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -573,12 +573,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& static_cast((total_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK); 
RepeatedLocationProvider rep_loc{list_offsets, base_offset, d_occurrences.data()}; extract_varint_kernel - <<>>(message_data, - rep_loc, - total_count, - enum_ints.data(), - nullptr, - d_error.data()); + <<>>( + message_data, rep_loc, total_count, enum_ints.data(), nullptr, d_error.data()); // 2. Build device-side enum lookup tables rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); @@ -639,9 +635,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& total_count); // 5. Build string offsets - auto [str_offs_col, total_chars] = - cudf::strings::detail::make_offsets_child_column( - elem_lengths.begin(), elem_lengths.end(), stream, mr); + auto [str_offs_col, total_chars] = cudf::strings::detail::make_offsets_child_column( + elem_lengths.begin(), elem_lengths.end(), stream, mr); // 6. Copy string chars rmm::device_uvector chars(total_chars, stream, mr); diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index a06b3b7b1d..3582e61633 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -253,7 +253,6 @@ std::unique_ptr make_empty_list_column(std::unique_ptr build_enum_string_column( rmm::device_uvector& enum_values, rmm::device_uvector& valid, @@ -322,8 +321,8 @@ std::unique_ptr build_enum_string_column( lengths.data(), num_rows); - auto [offsets_col, total_chars] = cudf::strings::detail::make_offsets_child_column( - lengths.begin(), lengths.end(), stream, mr); + auto [offsets_col, total_chars] = + cudf::strings::detail::make_offsets_child_column(lengths.begin(), lengths.end(), stream, mr); rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { @@ -421,13 +420,13 @@ std::unique_ptr build_repeated_string_column( CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); if (total_chars > 0) { RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; - copy_varlen_data_kernel<<>>( - 
message_data, - loc_provider, - total_count, - str_offsets_col->view().data(), - chars.data(), - d_error.data()); + copy_varlen_data_kernel + <<>>(message_data, + loc_provider, + total_count, + str_offsets_col->view().data(), + chars.data(), + d_error.data()); } std::unique_ptr child_col; @@ -1299,33 +1298,31 @@ std::unique_ptr build_repeated_child_list_column( std::unique_ptr child_values; auto const rep_blocks = static_cast((total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK); - NestedRepeatedLocationProvider nr_loc{ - row_offsets, base_offset, parent_locs, d_rep_occs.data()}; + NestedRepeatedLocationProvider nr_loc{row_offsets, base_offset, parent_locs, d_rep_occs.data()}; if (elem_type_id == cudf::type_id::BOOL8 || elem_type_id == cudf::type_id::INT32 || elem_type_id == cudf::type_id::UINT32 || elem_type_id == cudf::type_id::INT64 || elem_type_id == cudf::type_id::UINT64 || elem_type_id == cudf::type_id::FLOAT32 || elem_type_id == cudf::type_id::FLOAT64) { - child_values = - extract_typed_column(cudf::data_type{elem_type_id}, - schema[child_schema_idx].encoding, - message_data, - nr_loc, - total_rep_count, - rep_blocks, - THREADS_PER_BLOCK, - false, - 0, - 0.0, - false, - std::vector{}, - child_schema_idx, - enum_valid_values, - enum_names, - d_row_has_invalid_enum, - d_error, - stream, - mr); + child_values = extract_typed_column(cudf::data_type{elem_type_id}, + schema[child_schema_idx].encoding, + message_data, + nr_loc, + total_rep_count, + rep_blocks, + THREADS_PER_BLOCK, + false, + 0, + 0.0, + false, + std::vector{}, + child_schema_idx, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error, + stream, + mr); } else if (elem_type_id == cudf::type_id::STRING || elem_type_id == cudf::type_id::LIST) { bool as_bytes = (elem_type_id == cudf::type_id::LIST); auto valid_fn = [] __device__(cudf::size_type) { return true; }; diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 4a625581b8..405131a998 
100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -632,7 +632,6 @@ inline std::pair make_null_mask_from_valid( return cudf::detail::valid_if(begin, end, pred, stream, mr); } - template std::unique_ptr extract_and_build_scalar_column(cudf::data_type dt, int num_rows, @@ -817,88 +816,150 @@ std::unique_ptr make_empty_struct_column_with_schema( return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer{}, stream, mr); } - // ============================================================================ // Forward declarations of non-template __global__ kernels // ============================================================================ -__global__ void scan_all_fields_kernel( - cudf::column_device_view const d_in, field_descriptor const* field_descs, int num_fields, - int const* field_lookup, int field_lookup_size, field_location* locations, int* error_flag); - -__global__ void count_repeated_fields_kernel( - cudf::column_device_view const d_in, device_nested_field_descriptor const* schema, - int num_fields, int depth_level, repeated_field_info* repeated_info, int num_repeated_fields, - int const* repeated_field_indices, field_location* nested_locations, int num_nested_fields, - int const* nested_field_indices, int* error_flag); - -__global__ void scan_repeated_field_occurrences_kernel( - cudf::column_device_view const d_in, device_nested_field_descriptor const* schema, - int schema_idx, int depth_level, int32_t const* output_offsets, - repeated_occurrence* occurrences, int* error_flag); - -__global__ void scan_nested_message_fields_kernel( - uint8_t const* message_data, cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, field_location const* parent_locations, - int num_parent_rows, field_descriptor const* field_descs, int num_fields, - field_location* output_locations, int* error_flag); - -__global__ void scan_repeated_message_children_kernel( - uint8_t const* 
message_data, int32_t const* msg_row_offsets, field_location const* msg_locs, - int num_occurrences, field_descriptor const* child_descs, int num_child_fields, - field_location* child_locs, int* error_flag); - -__global__ void count_repeated_in_nested_kernel( - uint8_t const* message_data, cudf::size_type const* row_offsets, cudf::size_type base_offset, - field_location const* parent_locs, int num_rows, device_nested_field_descriptor const* schema, - int num_fields, repeated_field_info* repeated_info, int num_repeated, - int const* repeated_indices, int* error_flag); - -__global__ void scan_repeated_in_nested_kernel( - uint8_t const* message_data, cudf::size_type const* row_offsets, cudf::size_type base_offset, - field_location const* parent_locs, int num_rows, device_nested_field_descriptor const* schema, - int num_fields, int32_t const* occ_prefix_sums, int num_repeated, - int const* repeated_indices, repeated_occurrence* occurrences, int* error_flag); - -__global__ void compute_nested_struct_locations_kernel( - field_location const* child_locs, field_location const* msg_locs, - int32_t const* msg_row_offsets, int child_idx, int num_child_fields, - field_location* nested_locs, int32_t* nested_row_offsets, int total_count); - -__global__ void compute_grandchild_parent_locations_kernel( - field_location const* parent_locs, field_location const* child_locs, - int child_idx, int num_child_fields, field_location* gc_parent_abs, int num_rows); +__global__ void scan_all_fields_kernel(cudf::column_device_view const d_in, + field_descriptor const* field_descs, + int num_fields, + int const* field_lookup, + int field_lookup_size, + field_location* locations, + int* error_flag); + +__global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in, + device_nested_field_descriptor const* schema, + int num_fields, + int depth_level, + repeated_field_info* repeated_info, + int num_repeated_fields, + int const* repeated_field_indices, + field_location* 
nested_locations, + int num_nested_fields, + int const* nested_field_indices, + int* error_flag); + +__global__ void scan_repeated_field_occurrences_kernel(cudf::column_device_view const d_in, + device_nested_field_descriptor const* schema, + int schema_idx, + int depth_level, + int32_t const* output_offsets, + repeated_occurrence* occurrences, + int* error_flag); + +__global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + int num_parent_rows, + field_descriptor const* field_descs, + int num_fields, + field_location* output_locations, + int* error_flag); + +__global__ void scan_repeated_message_children_kernel(uint8_t const* message_data, + int32_t const* msg_row_offsets, + field_location const* msg_locs, + int num_occurrences, + field_descriptor const* child_descs, + int num_child_fields, + field_location* child_locs, + int* error_flag); + +__global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + repeated_field_info* repeated_info, + int num_repeated, + int const* repeated_indices, + int* error_flag); + +__global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + int32_t const* occ_prefix_sums, + int num_repeated, + int const* repeated_indices, + repeated_occurrence* occurrences, + int* error_flag); + +__global__ void compute_nested_struct_locations_kernel(field_location const* child_locs, + field_location const* msg_locs, + int32_t const* msg_row_offsets, + int child_idx, + int num_child_fields, + 
field_location* nested_locs, + int32_t* nested_row_offsets, + int total_count); + +__global__ void compute_grandchild_parent_locations_kernel(field_location const* parent_locs, + field_location const* child_locs, + int child_idx, + int num_child_fields, + field_location* gc_parent_abs, + int num_rows); __global__ void compute_virtual_parents_for_nested_repeated_kernel( - repeated_occurrence const* occurrences, cudf::size_type const* row_list_offsets, - field_location const* parent_locations, cudf::size_type* virtual_row_offsets, - field_location* virtual_parent_locs, int total_count); + repeated_occurrence const* occurrences, + cudf::size_type const* row_list_offsets, + field_location const* parent_locations, + cudf::size_type* virtual_row_offsets, + field_location* virtual_parent_locs, + int total_count); __global__ void compute_msg_locations_from_occurrences_kernel( - repeated_occurrence const* occurrences, cudf::size_type const* list_offsets, - cudf::size_type base_offset, field_location* msg_locs, int32_t* msg_row_offsets, + repeated_occurrence const* occurrences, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + field_location* msg_locs, + int32_t* msg_row_offsets, int total_count); -__global__ void extract_strided_locations_kernel( - field_location const* nested_locations, int field_idx, int num_fields, - field_location* parent_locs, int num_rows); - -__global__ void check_required_fields_kernel( - field_location const* locations, uint8_t const* is_required, int num_fields, - int num_rows, int* error_flag); - -__global__ void validate_enum_values_kernel( - int32_t const* values, bool* valid, bool* row_has_invalid_enum, - int32_t const* valid_enum_values, int num_valid_values, int num_rows); - -__global__ void compute_enum_string_lengths_kernel( - int32_t const* values, bool const* valid, int32_t const* valid_enum_values, - int32_t const* enum_name_offsets, int num_valid_values, int32_t* lengths, int num_rows); - -__global__ void 
copy_enum_string_chars_kernel( - int32_t const* values, bool const* valid, int32_t const* valid_enum_values, - int32_t const* enum_name_offsets, uint8_t const* enum_name_chars, int num_valid_values, - int32_t const* output_offsets, char* out_chars, int num_rows); +__global__ void extract_strided_locations_kernel(field_location const* nested_locations, + int field_idx, + int num_fields, + field_location* parent_locs, + int num_rows); + +__global__ void check_required_fields_kernel(field_location const* locations, + uint8_t const* is_required, + int num_fields, + int num_rows, + int* error_flag); + +__global__ void validate_enum_values_kernel(int32_t const* values, + bool* valid, + bool* row_has_invalid_enum, + int32_t const* valid_enum_values, + int num_valid_values, + int num_rows); + +__global__ void compute_enum_string_lengths_kernel(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + int num_valid_values, + int32_t* lengths, + int num_rows); + +__global__ void copy_enum_string_chars_kernel(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + uint8_t const* enum_name_chars, + int num_valid_values, + int32_t const* output_offsets, + char* out_chars, + int num_rows); // ============================================================================ // Forward declarations of builder/utility functions @@ -1036,8 +1097,8 @@ inline std::unique_ptr extract_and_build_string_or_bytes_column( extract_lengths_kernel<<>>( length_provider, num_rows, lengths.data(), has_default, def_len); - auto [offsets_col, total_size] = cudf::strings::detail::make_offsets_child_column( - lengths.begin(), lengths.end(), stream, mr); + auto [offsets_col, total_size] = + cudf::strings::detail::make_offsets_child_column(lengths.begin(), lengths.end(), stream, mr); rmm::device_uvector chars(total_size, stream, mr); if (total_size > 0) { From 
f37d3c91f71f087e144c8545a2831ed6a8d0fcf7 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 3 Mar 2026 12:13:48 +0800 Subject: [PATCH 036/107] style Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 50 ++++++++++------------------ src/main/cpp/src/protobuf_common.cuh | 14 ++++++-- 2 files changed, 30 insertions(+), 34 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 48c922acbe..6a77c44780 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -16,6 +16,8 @@ #include "protobuf_common.cuh" +#include + using namespace spark_rapids_jni::protobuf_detail; namespace spark_rapids_jni { @@ -80,6 +82,16 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); } + // Extract shared input data pointers (used by scalar, repeated, and nested sections) + cudf::lists_column_view const in_list_view(binary_input); + auto const* message_data = reinterpret_cast(in_list_view.child().data()); + auto const* list_offsets = in_list_view.offsets().data(); + + cudf::size_type base_offset = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync( + &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + // Copy schema to device std::vector h_device_schema(num_fields); for (int i = 0; i < num_fields; i++) { @@ -251,17 +263,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } } - // Extract scalar values (reusing existing extraction logic) - cudf::lists_column_view const in_list_view(binary_input); - auto const* message_data = - reinterpret_cast(in_list_view.child().data()); - auto const* list_offsets = in_list_view.offsets().data(); - - cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync( - &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - for (int i = 0; i < num_scalar; i++) { int 
schema_idx = scalar_field_indices[i]; auto const dt = schema_output_types[schema_idx]; @@ -401,15 +402,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Process repeated fields if (num_repeated > 0) { - cudf::lists_column_view const in_list_view(binary_input); - auto const* list_offsets = in_list_view.offsets().data(); - auto const* message_data = - reinterpret_cast(in_list_view.child().data()); - cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync( - &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - for (int ri = 0; ri < num_repeated; ri++) { int schema_idx = repeated_field_indices[ri]; auto element_type = schema_output_types[schema_idx]; @@ -424,6 +416,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& int64_t total_count = thrust::reduce( rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), int64_t{0}); + CUDF_EXPECTS(total_count <= std::numeric_limits::max(), + "Total repeated element count exceeds INT32_MAX"); if (total_count > 0) { // Build offsets for occurrence scanning on GPU (performance fix!) 
@@ -434,8 +428,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_occ_offsets.data(), 0); // Set last element + int32_t total_count_i32 = static_cast(total_count); CUDF_CUDA_TRY(cudaMemcpyAsync(d_occ_offsets.data() + num_rows, - &total_count, + &total_count_i32, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); @@ -664,8 +659,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_field_counts.end(), list_offs.begin(), 0); + int32_t tc_i32 = static_cast(total_count); CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, - &total_count, + &tc_i32, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); @@ -813,16 +809,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Process nested struct fields (Phase 2) if (num_nested > 0) { - cudf::lists_column_view const in_list_view(binary_input); - auto const* message_data = - reinterpret_cast(in_list_view.child().data()); - auto const* list_offsets = in_list_view.offsets().data(); - - cudf::size_type base_offset = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync( - &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - for (int ni = 0; ni < num_nested; ni++) { int parent_schema_idx = nested_field_indices[ni]; diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 405131a998..5bd0b43fb8 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -490,9 +490,19 @@ __global__ void extract_varint_kernel(uint8_t const* message_data, int32_t data_offset = 0; auto loc = loc_provider.get(idx, data_offset); + // For BOOL8 (uint8_t), protobuf spec says any non-zero varint is true. + // A raw static_cast would silently truncate values >= 256 to 0. + auto const write_value = [](OutputType* dst, uint64_t val) { + if constexpr (std::is_same_v) { + *dst = static_cast(val != 0 ? 
1 : 0); + } else { + *dst = static_cast(val); + } + }; + if (loc.offset < 0) { if (has_default) { - out[idx] = static_cast(default_value); + write_value(&out[idx], static_cast(default_value)); if (valid) valid[idx] = true; } else { if (valid) valid[idx] = false; @@ -512,7 +522,7 @@ __global__ void extract_varint_kernel(uint8_t const* message_data, } if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } - out[idx] = static_cast(v); + write_value(&out[idx], v); if (valid) valid[idx] = true; } From 78b1c60921582b3d051483aadecf9cd193edab05 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 3 Mar 2026 15:37:00 +0800 Subject: [PATCH 037/107] address cc comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 181 ++----- src/main/cpp/src/protobuf_builders.cu | 127 ++++- src/main/cpp/src/protobuf_common.cuh | 17 +- src/main/cpp/src/protobuf_kernels.cu | 481 ++++++++---------- .../com/nvidia/spark/rapids/jni/Protobuf.java | 59 --- .../rapids/jni/ProtobufSchemaDescriptor.java | 45 +- .../nvidia/spark/rapids/jni/ProtobufTest.java | 148 ++++-- 7 files changed, 497 insertions(+), 561 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 6a77c44780..6293e5d299 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -131,14 +131,28 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Error flag rmm::device_uvector d_error(1, stream, mr); CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + auto error_message = [](int code) -> char const* { + switch (code) { + case ERR_BOUNDS: return "Protobuf decode error: message data out of bounds"; + case ERR_VARINT: return "Protobuf decode error: invalid or truncated varint"; + case ERR_FIELD_NUMBER: return "Protobuf decode error: invalid field number"; + case ERR_WIRE_TYPE: return "Protobuf decode error: unexpected wire type"; + case ERR_OVERFLOW: return "Protobuf decode error: length-delimited field overflows 
message"; + case ERR_FIELD_SIZE: return "Protobuf decode error: invalid field size"; + case ERR_SKIP: return "Protobuf decode error: unable to skip unknown field"; + case ERR_FIXED_LEN: + return "Protobuf decode error: invalid fixed-width or packed field length"; + case ERR_REQUIRED: return "Protobuf decode error: missing required field"; + default: return "Protobuf decode error: unknown error"; + } + }; auto check_error_and_throw = [&]() { if (!fail_on_errors) return; int h_error = 0; CUDF_CUDA_TRY(cudaMemcpyAsync( &h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - CUDF_EXPECTS(h_error == 0, - "Malformed protobuf message, unsupported wire type, or missing required field"); + if (h_error != 0) { throw cudf::logic_error(error_message(h_error)); } }; // Enum validation support (PERMISSIVE mode) @@ -193,9 +207,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& check_error_and_throw(); } - // For scalar fields at depth 0, use the existing scan_all_fields_kernel - // Use a map to store columns by schema index, then assemble in order at the end - std::map> column_map; + // Store decoded columns by schema index for ordered assembly at the end. + std::vector> column_map(num_fields); // Process scalar fields using existing infrastructure if (num_scalar > 0) { @@ -558,140 +571,20 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& schema_idx < static_cast(enum_names.size()) && !enum_valid_values[schema_idx].empty() && enum_valid_values[schema_idx].size() == enum_names[schema_idx].size()) { - // Repeated enum-as-string: extract varints, then convert to strings. - auto const& valid_enums = enum_valid_values[schema_idx]; - auto const& name_bytes = enum_names[schema_idx]; - - // 1. 
Extract enum integer values from occurrences - rmm::device_uvector enum_ints(total_count, stream, mr); - auto const rep_blocks = - static_cast((total_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK); - RepeatedLocationProvider rep_loc{list_offsets, base_offset, d_occurrences.data()}; - extract_varint_kernel - <<>>( - message_data, rep_loc, total_count, enum_ints.data(), nullptr, d_error.data()); - - // 2. Build device-side enum lookup tables - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - std::vector h_name_offsets(valid_enums.size() + 1, 0); - int32_t total_name_chars = 0; - for (size_t k = 0; k < name_bytes.size(); ++k) { - total_name_chars += static_cast(name_bytes[k].size()); - h_name_offsets[k + 1] = total_name_chars; - } - std::vector h_name_chars(total_name_chars); - int32_t cursor = 0; - for (auto const& nm : name_bytes) { - if (!nm.empty()) { - std::copy(nm.data(), nm.data() + nm.size(), h_name_chars.data() + cursor); - cursor += static_cast(nm.size()); - } - } - rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), - h_name_offsets.data(), - h_name_offsets.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - rmm::device_uvector d_name_chars(total_name_chars, stream, mr); - if (total_name_chars > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), - h_name_chars.data(), - total_name_chars * sizeof(uint8_t), - cudaMemcpyHostToDevice, - stream.value())); - } - - // 3. Validate enum values (sets row_has_invalid_enum for PERMISSIVE mode). - // We also need per-element validity for string building. 
- rmm::device_uvector elem_valid(total_count, stream, mr); - thrust::fill(rmm::exec_policy(stream), elem_valid.data(), elem_valid.end(), true); - // validate_enum_values_kernel works on per-row basis; here we need per-element. - // Binary-search each element inline via the lengths kernel below. - - // 4. Compute per-element string lengths - rmm::device_uvector elem_lengths(total_count, stream, mr); - compute_enum_string_lengths_kernel<<>>( - enum_ints.data(), - elem_valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - static_cast(valid_enums.size()), - elem_lengths.data(), - total_count); - - // 5. Build string offsets - auto [str_offs_col, total_chars] = cudf::strings::detail::make_offsets_child_column( - elem_lengths.begin(), elem_lengths.end(), stream, mr); - - // 6. Copy string chars - rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - copy_enum_string_chars_kernel<<>>( - enum_ints.data(), - elem_valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - d_name_chars.data(), - static_cast(valid_enums.size()), - str_offs_col->view().data(), - chars.data(), - total_count); - } - - // 7. 
Assemble LIST column - auto child_col = cudf::make_strings_column( - total_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); - - // Build list offsets from per-row counts - rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - d_field_counts.begin(), - d_field_counts.end(), - list_offs.begin(), - 0); - int32_t tc_i32 = static_cast(total_count); - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, - &tc_i32, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - auto list_offs_col = - std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - list_offs.release(), - rmm::device_buffer{}, - 0); - - auto input_null_count = binary_input.null_count(); - if (input_null_count > 0) { - auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - column_map[schema_idx] = cudf::make_lists_column(num_rows, - std::move(list_offs_col), - std::move(child_col), - input_null_count, - std::move(null_mask), - stream, - mr); - } else { - column_map[schema_idx] = cudf::make_lists_column(num_rows, - std::move(list_offs_col), - std::move(child_col), - 0, - rmm::device_buffer{}, - stream, - mr); - } + column_map[schema_idx] = + build_repeated_enum_string_column(binary_input, + message_data, + list_offsets, + base_offset, + d_field_counts, + d_occurrences, + total_count, + num_rows, + enum_valid_values[schema_idx], + enum_names[schema_idx], + d_error, + stream, + mr); } else { column_map[schema_idx] = build_repeated_string_column(binary_input, message_data, @@ -854,9 +747,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& std::vector> top_level_children; for (int i = 0; i < num_fields; i++) { if (schema[i].parent_idx == -1) { // Top-level field - auto it = column_map.find(i); - if (it != column_map.end()) { - top_level_children.push_back(std::move(it->second)); + if (column_map[i]) { + top_level_children.push_back(std::move(column_map[i])); } else { 
if (schema[i].is_repeated) { auto const element_type = schema_output_types[i]; @@ -882,10 +774,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& CUDF_CUDA_TRY( cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); - if (fail_on_errors) { - CUDF_EXPECTS(h_error == 0, - "Malformed protobuf message, unsupported wire type, or missing required field"); - } + if (fail_on_errors && h_error != 0) { throw cudf::logic_error(error_message(h_error)); } // Build final struct with PERMISSIVE mode null mask for invalid enums cudf::size_type struct_null_count = 0; diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 3582e61633..50dc9e6617 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -343,6 +343,131 @@ std::unique_ptr build_enum_string_column( num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); } +std::unique_ptr build_repeated_enum_string_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + std::vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const rep_blocks = + static_cast((total_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK); + + // 1. Extract enum integer values from occurrences + rmm::device_uvector enum_ints(total_count, stream, mr); + RepeatedLocationProvider rep_loc{list_offsets, base_offset, d_occurrences.data()}; + extract_varint_kernel<<>>( + message_data, rep_loc, total_count, enum_ints.data(), nullptr, d_error.data()); + + // 2. 
Build device-side enum lookup tables + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + std::vector h_name_offsets(valid_enums.size() + 1, 0); + int32_t total_name_chars = 0; + for (size_t k = 0; k < enum_name_bytes.size(); ++k) { + total_name_chars += static_cast(enum_name_bytes[k].size()); + h_name_offsets[k + 1] = total_name_chars; + } + std::vector h_name_chars(total_name_chars); + int32_t cursor = 0; + for (auto const& nm : enum_name_bytes) { + if (!nm.empty()) { + std::copy(nm.data(), nm.data() + nm.size(), h_name_chars.data() + cursor); + cursor += static_cast(nm.size()); + } + } + rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), + h_name_offsets.data(), + h_name_offsets.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + rmm::device_uvector d_name_chars(total_name_chars, stream, mr); + if (total_name_chars > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), + h_name_chars.data(), + total_name_chars * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream.value())); + } + + // 3. Per-element validity + rmm::device_uvector elem_valid(total_count, stream, mr); + thrust::fill(rmm::exec_policy(stream), elem_valid.data(), elem_valid.end(), true); + + // 4. Compute per-element string lengths + rmm::device_uvector elem_lengths(total_count, stream, mr); + compute_enum_string_lengths_kernel<<>>( + enum_ints.data(), + elem_valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + static_cast(valid_enums.size()), + elem_lengths.data(), + total_count); + + // 5. Build string offsets + auto [str_offs_col, total_chars] = cudf::strings::detail::make_offsets_child_column( + elem_lengths.begin(), elem_lengths.end(), stream, mr); + + // 6. 
Copy string chars + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + copy_enum_string_chars_kernel<<>>( + enum_ints.data(), + elem_valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + d_name_chars.data(), + static_cast(valid_enums.size()), + str_offs_col->view().data(), + chars.data(), + total_count); + } + + // 7. Assemble strings child column + auto child_col = cudf::make_strings_column( + total_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); + + // 8. Build LIST column with list offsets from per-row counts + rmm::device_uvector lo(num_rows + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), lo.begin(), 0); + int32_t tc_i32 = static_cast(total_count); + CUDF_CUDA_TRY(cudaMemcpyAsync( + lo.data() + num_rows, &tc_i32, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + + auto list_offs_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, lo.release(), rmm::device_buffer{}, 0); + + auto input_null_count = binary_input.null_count(); + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(list_offs_col), + std::move(child_col), + input_null_count, + std::move(null_mask), + stream, + mr); + } + return cudf::make_lists_column( + num_rows, std::move(list_offs_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); +} + std::unique_ptr build_repeated_string_column( cudf::column_view const& binary_input, uint8_t const* message_data, @@ -1288,9 +1413,7 @@ std::unique_ptr build_repeated_child_list_column( parent_locs, num_parent_rows, d_rep_schema.data(), - 1, list_offs.data(), - 1, d_rep_indices.data(), d_rep_occs.data(), d_error.data()); diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 5bd0b43fb8..eb99e26e98 100644 --- a/src/main/cpp/src/protobuf_common.cuh 
+++ b/src/main/cpp/src/protobuf_common.cuh @@ -895,9 +895,7 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, field_location const* parent_locs, int num_rows, device_nested_field_descriptor const* schema, - int num_fields, int32_t const* occ_prefix_sums, - int num_repeated, int const* repeated_indices, repeated_occurrence* occurrences, int* error_flag); @@ -997,6 +995,21 @@ std::unique_ptr build_enum_string_column( rmm::device_async_resource_ref mr); // Complex builder forward declarations +std::unique_ptr build_repeated_enum_string_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + std::vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + std::unique_ptr build_repeated_string_column( cudf::column_view const& binary_input, uint8_t const* message_data, diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index a12a4c6abe..d5095972f3 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -111,6 +111,167 @@ __global__ void scan_all_fields_kernel( } } +// ============================================================================ +// Shared device functions for repeated field processing +// ============================================================================ + +/** + * Count a single repeated field occurrence (packed or unpacked). + * Updates info.count and info.total_length. + * Returns false on error (error_flag set), true on success. 
+ */ +__device__ bool count_repeated_element(uint8_t const* cur, + uint8_t const* msg_end, + int wt, + int expected_wt, + repeated_field_info& info, + int* error_flag) +{ + bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); + + if (!is_packed && wt != expected_wt) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return false; + } + + if (is_packed) { + uint64_t packed_len; + int len_bytes; + if (!read_varint(cur, msg_end, packed_len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + uint8_t const* packed_start = cur + len_bytes; + uint8_t const* packed_end = packed_start + packed_len; + if (packed_end > msg_end) { + set_error_once(error_flag, ERR_OVERFLOW); + return false; + } + + int count = 0; + if (expected_wt == WT_VARINT) { + uint8_t const* p = packed_start; + while (p < packed_end) { + uint64_t dummy; + int vbytes; + if (!read_varint(p, packed_end, dummy, vbytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + p += vbytes; + count++; + } + } else if (expected_wt == WT_32BIT) { + if ((packed_len % 4) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return false; + } + count = static_cast(packed_len / 4); + } else if (expected_wt == WT_64BIT) { + if ((packed_len % 8) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return false; + } + count = static_cast(packed_len / 8); + } + + info.count += count; + info.total_length += static_cast(packed_len); + } else { + int32_t data_offset, data_length; + if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { + set_error_once(error_flag, ERR_FIELD_SIZE); + return false; + } + info.count++; + info.total_length += data_length; + } + return true; +} + +/** + * Record a single repeated field occurrence (packed or unpacked). + * Writes to occurrences[write_idx] and advances write_idx. + * Offsets are computed relative to msg_base. + * Returns false on error (error_flag set), true on success. 
+ */ +__device__ bool scan_repeated_element(uint8_t const* cur, + uint8_t const* msg_end, + uint8_t const* msg_base, + int wt, + int expected_wt, + int32_t row, + repeated_occurrence* occurrences, + int& write_idx, + int* error_flag) +{ + bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); + + if (!is_packed && wt != expected_wt) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return false; + } + + if (is_packed) { + uint64_t packed_len; + int len_bytes; + if (!read_varint(cur, msg_end, packed_len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + uint8_t const* packed_start = cur + len_bytes; + uint8_t const* packed_end = packed_start + packed_len; + if (packed_end > msg_end) { + set_error_once(error_flag, ERR_OVERFLOW); + return false; + } + + if (expected_wt == WT_VARINT) { + uint8_t const* p = packed_start; + while (p < packed_end) { + int32_t elem_offset = static_cast(p - msg_base); + uint64_t dummy; + int vbytes; + if (!read_varint(p, packed_end, dummy, vbytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + occurrences[write_idx] = {row, elem_offset, vbytes}; + write_idx++; + p += vbytes; + } + } else if (expected_wt == WT_32BIT) { + if ((packed_len % 4) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return false; + } + for (uint64_t i = 0; i < packed_len; i += 4) { + occurrences[write_idx] = {row, static_cast(packed_start - msg_base + i), 4}; + write_idx++; + } + } else if (expected_wt == WT_64BIT) { + if ((packed_len % 8) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return false; + } + for (uint64_t i = 0; i < packed_len; i += 8) { + occurrences[write_idx] = {row, static_cast(packed_start - msg_base + i), 8}; + write_idx++; + } + } + } else { + int32_t data_offset, data_length; + if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { + set_error_once(error_flag, ERR_FIELD_SIZE); + return false; + } + int32_t abs_offset = static_cast(cur - msg_base) + data_offset; + 
occurrences[write_idx] = {row, abs_offset, data_length}; + write_idx++; + } + return true; +} + // ============================================================================ // Pass 1b: Count repeated fields kernel // ============================================================================ @@ -167,80 +328,17 @@ __global__ void count_repeated_fields_kernel( int fn = tag.field_number; int wt = tag.wire_type; - // Check repeated fields at this depth for (int i = 0; i < num_repeated_fields; i++) { int schema_idx = repeated_field_indices[i]; if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { - int expected_wt = schema[schema_idx].wire_type; - - // Handle both packed and unpacked encoding for repeated fields - // Packed encoding uses wire type LEN (2) even for scalar types - bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); - - if (!is_packed && wt != expected_wt) { - set_error_once(error_flag, ERR_WIRE_TYPE); + if (!count_repeated_element(cur, + msg_end, + wt, + schema[schema_idx].wire_type, + repeated_info[row * num_repeated_fields + i], + error_flag)) { return; } - - if (is_packed) { - // Packed encoding: read length, then count elements inside - uint64_t packed_len; - int len_bytes; - if (!read_varint(cur, msg_end, packed_len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - - // Count elements based on type - uint8_t const* packed_start = cur + len_bytes; - uint8_t const* packed_end = packed_start + packed_len; - if (packed_end > msg_end) { - set_error_once(error_flag, ERR_OVERFLOW); - return; - } - - int count = 0; - if (expected_wt == WT_VARINT) { - // Count varints in the packed data - uint8_t const* p = packed_start; - while (p < packed_end) { - uint64_t dummy; - int vbytes; - if (!read_varint(p, packed_end, dummy, vbytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - p += vbytes; - count++; - } - } else if (expected_wt == WT_32BIT) { - if ((packed_len % 4) != 0) { - 
set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - count = static_cast(packed_len / 4); - } else if (expected_wt == WT_64BIT) { - if ((packed_len % 8) != 0) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - count = static_cast(packed_len / 8); - } - - repeated_info[row * num_repeated_fields + i].count += count; - repeated_info[row * num_repeated_fields + i].total_length += - static_cast(packed_len); - } else { - // Non-packed encoding: single element - int32_t data_offset, data_length; - if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { - set_error_once(error_flag, ERR_FIELD_SIZE); - return; - } - - repeated_info[row * num_repeated_fields + i].count++; - repeated_info[row * num_repeated_fields + i].total_length += data_length; - } } } @@ -317,71 +415,19 @@ __global__ void scan_repeated_field_occurrences_kernel( int wt = tag.wire_type; if (fn == target_fn) { - // Check for packed encoding: wire type LEN but expected non-LEN bool is_packed = (wt == WT_LEN && target_wt != WT_LEN); - - if (is_packed) { - // Packed encoding: multiple elements in a length-delimited blob - uint64_t packed_len; - int len_bytes; - if (!read_varint(cur, msg_end, packed_len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); + if (is_packed || wt == target_wt) { + if (!scan_repeated_element(cur, + msg_end, + bytes + start, + wt, + target_wt, + static_cast(row), + occurrences, + write_idx, + error_flag)) { return; } - - uint8_t const* packed_start = cur + len_bytes; - uint8_t const* packed_end = packed_start + packed_len; - if (packed_end > msg_end) { - set_error_once(error_flag, ERR_OVERFLOW); - return; - } - - // Record each element in the packed blob - if (target_wt == WT_VARINT) { - // Varints: parse each one - uint8_t const* p = packed_start; - while (p < packed_end) { - int32_t elem_offset = static_cast(p - bytes - start); - uint64_t dummy; - int vbytes; - if (!read_varint(p, packed_end, dummy, vbytes)) { - set_error_once(error_flag, 
ERR_VARINT); - return; - } - occurrences[write_idx] = {static_cast(row), elem_offset, vbytes}; - write_idx++; - p += vbytes; - } - } else if (target_wt == WT_32BIT) { - // Fixed 32-bit: each element is 4 bytes - uint8_t const* p = packed_start; - while (p + 4 <= packed_end) { - int32_t elem_offset = static_cast(p - bytes - start); - occurrences[write_idx] = {static_cast(row), elem_offset, 4}; - write_idx++; - p += 4; - } - } else if (target_wt == WT_64BIT) { - // Fixed 64-bit: each element is 8 bytes - uint8_t const* p = packed_start; - while (p + 8 <= packed_end) { - int32_t elem_offset = static_cast(p - bytes - start); - occurrences[write_idx] = {static_cast(row), elem_offset, 8}; - write_idx++; - p += 8; - } - } - } else if (wt == target_wt) { - // Non-packed encoding: single element - int32_t data_offset, data_length; - if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { - set_error_once(error_flag, ERR_FIELD_SIZE); - return; - } - - int32_t abs_offset = static_cast(cur - bytes - start) + data_offset; - occurrences[write_idx] = {static_cast(row), abs_offset, data_length}; - write_idx++; } } @@ -608,69 +654,17 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, int fn = tag.field_number; int wt = tag.wire_type; - // Check if this is one of our repeated fields for (int ri = 0; ri < num_repeated; ri++) { int schema_idx = repeated_indices[ri]; if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { - int expected_wt = schema[schema_idx].wire_type; - bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); - - if (!is_packed && wt != expected_wt) { - set_error_once(error_flag, ERR_WIRE_TYPE); + if (!count_repeated_element(cur, + msg_end, + wt, + schema[schema_idx].wire_type, + repeated_info[row * num_repeated + ri], + error_flag)) { return; } - - if (is_packed) { - uint64_t packed_len; - int len_bytes; - if (!read_varint(cur, msg_end, packed_len, len_bytes)) { - set_error_once(error_flag, 
ERR_VARINT); - return; - } - uint8_t const* packed_start = cur + len_bytes; - uint8_t const* packed_end = packed_start + packed_len; - if (packed_end > msg_end) { - set_error_once(error_flag, ERR_OVERFLOW); - return; - } - - int count = 0; - if (expected_wt == WT_VARINT) { - uint8_t const* p = packed_start; - while (p < packed_end) { - uint64_t dummy; - int vbytes; - if (!read_varint(p, packed_end, dummy, vbytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - p += vbytes; - count++; - } - } else if (expected_wt == WT_32BIT) { - if ((packed_len % 4) != 0) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - count = static_cast(packed_len / 4); - } else if (expected_wt == WT_64BIT) { - if ((packed_len % 8) != 0) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - count = static_cast(packed_len / 8); - } - repeated_info[row * num_repeated + ri].count += count; - repeated_info[row * num_repeated + ri].total_length += static_cast(packed_len); - } else { - int32_t data_offset, data_len; - if (!get_field_data_location(cur, msg_end, wt, data_offset, data_len)) { - set_error_once(error_flag, ERR_FIELD_SIZE); - return; - } - repeated_info[row * num_repeated + ri].count++; - repeated_info[row * num_repeated + ri].total_length += data_len; - } } } @@ -684,7 +678,9 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, } /** - * Scan for repeated field occurrences within nested messages. + * Scan for repeated field occurrences of a single repeated field within nested + * messages. Must be launched once per repeated field — the caller passes + * exactly one schema index via repeated_indices[0]. 
*/ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, cudf::size_type const* row_offsets, @@ -692,9 +688,7 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, field_location const* parent_locs, int num_rows, device_nested_field_descriptor const* schema, - int num_fields, int32_t const* occ_prefix_sums, - int num_repeated, int const* repeated_indices, repeated_occurrence* occurrences, int* error_flag) @@ -705,16 +699,14 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, auto const& parent_loc = parent_locs[row]; if (parent_loc.offset < 0) return; - // Prefix sum gives the write start offset for this row. - int occ_offset = occ_prefix_sums[row]; - cudf::size_type row_off = row_offsets[row] - base_offset; uint8_t const* msg_start = message_data + row_off + parent_loc.offset; uint8_t const* msg_end = msg_start + parent_loc.length; uint8_t const* cur = msg_start; - int occ_idx = 0; + int write_idx = occ_prefix_sums[row]; + int schema_idx = repeated_indices[0]; while (cur < msg_end) { proto_tag tag; @@ -722,92 +714,17 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, int fn = tag.field_number; int wt = tag.wire_type; - // Check if this is one of our repeated fields. 
- for (int ri = 0; ri < num_repeated; ri++) { - int schema_idx = repeated_indices[ri]; - if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { - int expected_wt = schema[schema_idx].wire_type; - bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); - - if (!is_packed && wt != expected_wt) { - set_error_once(error_flag, ERR_WIRE_TYPE); - return; - } - - if (is_packed) { - uint64_t packed_len; - int len_bytes; - if (!read_varint(cur, msg_end, packed_len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - uint8_t const* packed_start = cur + len_bytes; - uint8_t const* packed_end = packed_start + packed_len; - if (packed_end > msg_end) { - set_error_once(error_flag, ERR_OVERFLOW); - return; - } - - if (expected_wt == WT_VARINT) { - uint8_t const* p = packed_start; - while (p < packed_end) { - int32_t elem_offset = static_cast(p - msg_start); - uint64_t dummy; - int vbytes; - if (!read_varint(p, packed_end, dummy, vbytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - occurrences[occ_offset + occ_idx] = {row, elem_offset, vbytes}; - occ_idx++; - p += vbytes; - } - } else if (expected_wt == WT_32BIT) { - if ((packed_len % 4) != 0) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - for (uint64_t i = 0; i < packed_len; i += 4) { - occurrences[occ_offset + occ_idx] = { - row, static_cast(packed_start - msg_start + i), 4}; - occ_idx++; - } - } else if (expected_wt == WT_64BIT) { - if ((packed_len % 8) != 0) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - for (uint64_t i = 0; i < packed_len; i += 8) { - occurrences[occ_offset + occ_idx] = { - row, static_cast(packed_start - msg_start + i), 8}; - occ_idx++; - } - } - } else { - int32_t data_offset = static_cast(cur - msg_start); - int32_t data_len = 0; - if (wt == WT_LEN) { - uint64_t len; - int len_bytes; - if (!read_varint(cur, msg_end, len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - data_offset += 
len_bytes; - data_len = static_cast(len); - } else if (wt == WT_VARINT) { - uint64_t dummy; - int vbytes; - if (read_varint(cur, msg_end, dummy, vbytes)) { data_len = vbytes; } - } else if (wt == WT_32BIT) { - data_len = 4; - } else if (wt == WT_64BIT) { - data_len = 8; - } - - occurrences[occ_offset + occ_idx] = {row, data_offset, data_len}; - occ_idx++; - } + if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { + if (!scan_repeated_element(cur, + msg_end, + msg_start, + wt, + schema[schema_idx].wire_type, + static_cast(row), + occurrences, + write_idx, + error_flag)) { + return; } } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index bd8f9632d0..43f2a3eb01 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -82,65 +82,6 @@ public static ColumnVector decodeToStruct(ColumnView binaryInput, return new ColumnVector(handle); } - /** - * Decode protobuf messages using individual parallel arrays. - * - * @deprecated Use {@link #decodeToStruct(ColumnView, ProtobufSchemaDescriptor, boolean)} instead. 
- */ - @Deprecated - public static ColumnVector decodeToStruct(ColumnView binaryInput, - int[] fieldNumbers, - int[] parentIndices, - int[] depthLevels, - int[] wireTypes, - int[] outputTypeIds, - int[] encodings, - boolean[] isRepeated, - boolean[] isRequired, - boolean[] hasDefaultValue, - long[] defaultInts, - double[] defaultFloats, - boolean[] defaultBools, - byte[][] defaultStrings, - int[][] enumValidValues, - byte[][][] enumNames, - boolean failOnErrors) { - return decodeToStruct(binaryInput, - new ProtobufSchemaDescriptor(fieldNumbers, parentIndices, depthLevels, - wireTypes, outputTypeIds, encodings, isRepeated, isRequired, - hasDefaultValue, defaultInts, defaultFloats, defaultBools, - defaultStrings, enumValidValues, enumNames), - failOnErrors); - } - - /** - * Backward-compatible overload without enum name mappings. - * - * @deprecated Use {@link #decodeToStruct(ColumnView, ProtobufSchemaDescriptor, boolean)} instead. - */ - @Deprecated - public static ColumnVector decodeToStruct(ColumnView binaryInput, - int[] fieldNumbers, - int[] parentIndices, - int[] depthLevels, - int[] wireTypes, - int[] outputTypeIds, - int[] encodings, - boolean[] isRepeated, - boolean[] isRequired, - boolean[] hasDefaultValue, - long[] defaultInts, - double[] defaultFloats, - boolean[] defaultBools, - byte[][] defaultStrings, - int[][] enumValidValues, - boolean failOnErrors) { - return decodeToStruct(binaryInput, fieldNumbers, parentIndices, depthLevels, wireTypes, - outputTypeIds, encodings, isRepeated, isRequired, hasDefaultValue, defaultInts, - defaultFloats, defaultBools, defaultStrings, enumValidValues, - new byte[fieldNumbers.length][][], failOnErrors); - } - private static native long decodeToStruct(long binaryInputView, int[] fieldNumbers, int[] parentIndices, diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index 513ead2c37..69f7e76b7d 100644 --- 
a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -22,26 +22,31 @@ * *

    Use this class instead of passing 15+ individual arrays through the JNI boundary. * Validation is performed once in the constructor. + * + *

    The arrays are intentionally exposed as package-private (not public) to allow + * zero-copy access from {@link Protobuf} within the same package, while preventing + * external code from mutating the contents after construction. Callers outside this + * package should treat instances as opaque and immutable. */ public final class ProtobufSchemaDescriptor implements java.io.Serializable { private static final long serialVersionUID = 1L; private static final int MAX_FIELD_NUMBER = (1 << 29) - 1; - public final int[] fieldNumbers; - public final int[] parentIndices; - public final int[] depthLevels; - public final int[] wireTypes; - public final int[] outputTypeIds; - public final int[] encodings; - public final boolean[] isRepeated; - public final boolean[] isRequired; - public final boolean[] hasDefaultValue; - public final long[] defaultInts; - public final double[] defaultFloats; - public final boolean[] defaultBools; - public final byte[][] defaultStrings; - public final int[][] enumValidValues; - public final byte[][][] enumNames; + final int[] fieldNumbers; + final int[] parentIndices; + final int[] depthLevels; + final int[] wireTypes; + final int[] outputTypeIds; + final int[] encodings; + final boolean[] isRepeated; + final boolean[] isRequired; + final boolean[] hasDefaultValue; + final long[] defaultInts; + final double[] defaultFloats; + final boolean[] defaultBools; + final byte[][] defaultStrings; + final int[][] enumValidValues; + final byte[][][] enumNames; /** * @throws IllegalArgumentException if any array is null, arrays have mismatched lengths, @@ -94,6 +99,16 @@ public ProtobufSchemaDescriptor( throw new IllegalArgumentException( "Invalid encoding at index " + i + ": " + enc); } + if (enumValidValues[i] != null) { + int[] ev = enumValidValues[i]; + for (int j = 1; j < ev.length; j++) { + if (ev[j] < ev[j - 1]) { + throw new IllegalArgumentException( + "enumValidValues[" + i + "] must be sorted in ascending order " + + "(binary search requires 
it), but found " + ev[j - 1] + " before " + ev[j]); + } + } + } } this.fieldNumbers = fieldNumbers; diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 6f65e71d55..a1d241721a 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -28,6 +28,8 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; import java.io.ByteArrayOutputStream; import java.nio.ByteBuffer; @@ -168,6 +170,41 @@ private static int getWireType(int cudfTypeId, int encoding) { return Protobuf.WT_VARINT; } + /** + * Test-only convenience: wrap raw parallel arrays into a ProtobufSchemaDescriptor + * and decode. Avoids verbose ProtobufSchemaDescriptor construction at every call site. 
+ */ + private static ColumnVector decodeRaw(ColumnView binaryInput, + int[] fieldNumbers, int[] parentIndices, int[] depthLevels, + int[] wireTypes, int[] outputTypeIds, int[] encodings, + boolean[] isRepeated, boolean[] isRequired, + boolean[] hasDefaultValue, long[] defaultInts, + double[] defaultFloats, boolean[] defaultBools, + byte[][] defaultStrings, int[][] enumValidValues, + boolean failOnErrors) { + return decodeRaw(binaryInput, fieldNumbers, parentIndices, depthLevels, + wireTypes, outputTypeIds, encodings, isRepeated, isRequired, + hasDefaultValue, defaultInts, defaultFloats, defaultBools, + defaultStrings, enumValidValues, new byte[fieldNumbers.length][][], failOnErrors); + } + + private static ColumnVector decodeRaw(ColumnView binaryInput, + int[] fieldNumbers, int[] parentIndices, int[] depthLevels, + int[] wireTypes, int[] outputTypeIds, int[] encodings, + boolean[] isRepeated, boolean[] isRequired, + boolean[] hasDefaultValue, long[] defaultInts, + double[] defaultFloats, boolean[] defaultBools, + byte[][] defaultStrings, int[][] enumValidValues, + byte[][][] enumNames, + boolean failOnErrors) { + return Protobuf.decodeToStruct(binaryInput, + new ProtobufSchemaDescriptor(fieldNumbers, parentIndices, depthLevels, + wireTypes, outputTypeIds, encodings, isRepeated, isRequired, + hasDefaultValue, defaultInts, defaultFloats, defaultBools, + defaultStrings, enumValidValues, enumNames), + failOnErrors); + } + /** * Helper to decode all scalar fields using the unified API. * Builds a flat schema (parentIndices=-1, depth=0, isRepeated=false for all fields). 
@@ -197,9 +234,12 @@ private static ColumnVector decodeScalarFields(ColumnView binaryInput, wireTypes[i] = getWireType(typeIds[i], encodings[i]); } - return Protobuf.decodeToStruct(binaryInput, fieldNumbers, parentIndices, depthLevels, - wireTypes, typeIds, encodings, isRepeated, isRequired, hasDefaultValue, - defaultInts, defaultFloats, defaultBools, defaultStrings, enumValidValues, failOnErrors); + return Protobuf.decodeToStruct(binaryInput, + new ProtobufSchemaDescriptor(fieldNumbers, parentIndices, depthLevels, + wireTypes, typeIds, encodings, isRepeated, isRequired, hasDefaultValue, + defaultInts, defaultFloats, defaultBools, defaultStrings, enumValidValues, + new byte[fieldNumbers.length][][]), + failOnErrors); } /** @@ -1663,7 +1703,7 @@ void testUnpackedRepeatedInt32() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { // Use the new nested API for repeated fields // Field: ids (field_number=1, parent=-1, depth=0, wire_type=VARINT, type=INT32, repeated=true) - try (ColumnVector result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeRaw( input.getColumn(0), new int[]{1}, // fieldNumbers new int[]{-1}, // parentIndices (-1 = top level) @@ -1737,7 +1777,7 @@ void testPackedRepeatedDoubleWithMultipleFields() { box(tag(3, WT_LEN)), box(encodeVarint(r2Doubles.length)), box(r2Doubles)); try (Table input = new Table.TestBuilder().column(new Byte[][]{row0, row1, row2}).build()) { - try (ColumnVector result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeRaw( input.getColumn(0), new int[]{1, 2, 3, 4}, new int[]{-1, -1, -1, -1}, @@ -1811,7 +1851,7 @@ void testNestedMessage() { // Flattened schema: // [0] inner: STRUCT, field_number=1, parent=-1, depth=0 // [1] inner.x: INT32, field_number=1, parent=0, depth=1 - try (ColumnVector result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeRaw( input.getColumn(0), new int[]{1, 1}, // fieldNumbers new int[]{-1, 0}, // parentIndices @@ -1859,7 +1899,7 @@ 
void testDeepNestedMessageDepth3() { ColumnVector expectedMiddle = ColumnVector.makeStruct(expectedInner, expectedM); ColumnVector expectedScore = ColumnVector.fromBoxedFloats(1.25f); ColumnVector expectedStruct = ColumnVector.makeStruct(expectedMiddle, expectedScore); - ColumnVector actualStruct = Protobuf.decodeToStruct( + ColumnVector actualStruct = decodeRaw( input.getColumn(0), new int[]{1, 1, 1, 2, 3, 2, 2}, // fieldNumbers new int[]{-1, 0, 1, 1, 1, 0, -1}, // parentIndices @@ -1901,7 +1941,7 @@ void testPackedRepeatedInsideNestedMessage() { inner); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeRaw( input.getColumn(0), new int[]{1, 1}, // outer.inner, inner.ids new int[]{-1, 0}, @@ -1939,7 +1979,7 @@ void testRepeatedUint32() { box(tag(1, WT_VARINT)), box(encodeVarint(3))); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeRaw( input.getColumn(0), new int[]{1}, new int[]{-1}, @@ -1972,7 +2012,7 @@ void testRepeatedUint64() { box(tag(1, WT_VARINT)), box(encodeVarint(33))); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeRaw( input.getColumn(0), new int[]{1}, new int[]{-1}, @@ -2006,7 +2046,7 @@ void testWireTypeMismatchInRepeatedMessageChildFailfast() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { assertThrows(ai.rapids.cudf.CudfException.class, () -> { - try (ColumnVector ignored = Protobuf.decodeToStruct( + try (ColumnVector ignored = decodeRaw( input.getColumn(0), new int[]{1, 1}, new int[]{-1, 0}, @@ -2253,23 +2293,12 @@ private static ColumnVector decodeAllFieldsWithEnumStrings(ColumnView binaryInpu boolean[] isRepeated = new boolean[numFields]; java.util.Arrays.fill(parentIndices, -1); 
java.util.Arrays.fill(wireTypes, Protobuf.WT_VARINT); - return Protobuf.decodeToStruct( - binaryInput, - fieldNumbers, - parentIndices, - depthLevels, - wireTypes, - typeIds, - encodings, - isRepeated, - new boolean[numFields], - new boolean[numFields], - new long[numFields], - new double[numFields], - new boolean[numFields], - new byte[numFields][], - enumValidValues, - enumNames, + return Protobuf.decodeToStruct(binaryInput, + new ProtobufSchemaDescriptor(fieldNumbers, parentIndices, depthLevels, + wireTypes, typeIds, encodings, isRepeated, + new boolean[numFields], new boolean[numFields], new long[numFields], + new double[numFields], new boolean[numFields], new byte[numFields][], + enumValidValues, enumNames), failOnErrors); } @@ -2318,8 +2347,8 @@ void testEnumAsStringUnknownValueReturnsNullRow() { enumNames, false); HostColumnVector hostStruct = actual.copyToHost()) { - assert actual.getNullCount() == 1 : "Struct row should be null for unknown enum value"; - assert hostStruct.isNull(0) : "Row 0 should be null"; + assertEquals(1, actual.getNullCount(), "Struct row should be null for unknown enum value"); + assertTrue(hostStruct.isNull(0), "Row 0 should be null"); } } @@ -2344,10 +2373,10 @@ void testEnumAsStringMixedValidAndUnknown() { enumNames, false); HostColumnVector hostStruct = actual.copyToHost()) { - assert actual.getNullCount() == 1 : "Only the unknown enum row should be null"; - assert !hostStruct.isNull(0) : "Row 0 should be valid"; - assert hostStruct.isNull(1) : "Row 1 should be null"; - assert !hostStruct.isNull(2) : "Row 2 should be valid"; + assertEquals(1, actual.getNullCount(), "Only the unknown enum row should be null"); + assertTrue(!hostStruct.isNull(0), "Row 0 should be valid"); + assertTrue(hostStruct.isNull(1), "Row 1 should be null"); + assertTrue(!hostStruct.isNull(2), "Row 2 should be valid"); } } @@ -2390,8 +2419,8 @@ void testEnumUnknownValueReturnsNullRow() { false); HostColumnVector hostStruct = actualStruct.copyToHost()) { // The 
struct itself should be null (not just the field) - assert actualStruct.getNullCount() == 1 : "Struct row should be null for unknown enum"; - assert hostStruct.isNull(0) : "Row 0 should be null"; + assertEquals(1, actualStruct.getNullCount(), "Struct row should be null for unknown enum"); + assertTrue(hostStruct.isNull(0), "Row 0 should be null"); } } @@ -2414,11 +2443,11 @@ void testEnumMixedValidAndUnknown() { false); HostColumnVector hostStruct = actualStruct.copyToHost()) { // Check struct-level nulls - assert actualStruct.getNullCount() == 2 : "Should have 2 null rows (rows 1 and 3)"; - assert !hostStruct.isNull(0) : "Row 0 should be valid"; - assert hostStruct.isNull(1) : "Row 1 should be null (unknown enum 999)"; - assert !hostStruct.isNull(2) : "Row 2 should be valid"; - assert hostStruct.isNull(3) : "Row 3 should be null (unknown enum -1)"; + assertEquals(2, actualStruct.getNullCount(), "Should have 2 null rows (rows 1 and 3)"); + assertTrue(!hostStruct.isNull(0), "Row 0 should be valid"); + assertTrue(hostStruct.isNull(1), "Row 1 should be null (unknown enum 999)"); + assertTrue(!hostStruct.isNull(2), "Row 2 should be valid"); + assertTrue(hostStruct.isNull(3), "Row 3 should be null (unknown enum -1)"); } } @@ -2441,8 +2470,8 @@ void testEnumWithOtherFields_NullsEntireRow() { false); HostColumnVector hostStruct = actualStruct.copyToHost()) { // The entire struct row should be null - assert actualStruct.getNullCount() == 1 : "Struct row should be null"; - assert hostStruct.isNull(0) : "Row 0 should be null due to unknown enum"; + assertEquals(1, actualStruct.getNullCount(), "Struct row should be null"); + assertTrue(hostStruct.isNull(0), "Row 0 should be null due to unknown enum"); } } @@ -2463,7 +2492,7 @@ void testEnumMissingFieldDoesNotNullRow() { new int[][]{{0, 1, 2}}, // valid enum values false)) { // Struct row should be valid (not null), only the field is null - assert actualStruct.getNullCount() == 0 : "Struct row should NOT be null for missing 
field"; + assertEquals(0, actualStruct.getNullCount(), "Struct row should NOT be null for missing field"); AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); } } @@ -2488,7 +2517,7 @@ void testEnumValidWithOtherFields() { new int[][]{{0, 1, 2}, null}, // first field is enum, second is regular int false)) { // Struct row should be valid with correct values - assert actualStruct.getNullCount() == 0 : "Struct row should be valid"; + assertEquals(0, actualStruct.getNullCount(), "Struct row should be valid"); AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); } } @@ -2514,7 +2543,7 @@ void testRepeatedEnumAsString() { } }; try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector actual = Protobuf.decodeToStruct( + ColumnVector actual = decodeRaw( input.getColumn(0), new int[]{1}, new int[]{-1}, @@ -2534,8 +2563,17 @@ void testRepeatedEnumAsString() { false)) { assertNotNull(actual); assertEquals(DType.STRUCT, actual.getType()); - // The struct has 1 child: a LIST column with ["RED", "BLUE", "GREEN"] assertEquals(1, actual.getNumChildren()); + try (ColumnView listCol = actual.getChildColumnView(0)) { + assertEquals(DType.LIST, listCol.getType()); + try (ColumnView strChild = listCol.getChildColumnView(0); + HostColumnVector hostStrs = strChild.copyToHost()) { + assertEquals(3, hostStrs.getRowCount()); + assertEquals("RED", hostStrs.getJavaString(0)); + assertEquals("BLUE", hostStrs.getJavaString(1)); + assertEquals("GREEN", hostStrs.getJavaString(2)); + } + } } } @@ -2553,7 +2591,7 @@ void testPackedFixedMisaligned() { try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { assertThrows(RuntimeException.class, () -> { - try (ColumnVector result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeRaw( input.getColumn(0), new int[]{1}, new int[]{-1}, @@ -2585,7 +2623,7 @@ void testPackedFixedMisaligned64() { try (Table input = new Table.TestBuilder().column(new 
Byte[][]{row}).build()) { assertThrows(RuntimeException.class, () -> { - try (ColumnVector result = Protobuf.decodeToStruct( + try (ColumnVector result = decodeRaw( input.getColumn(0), new int[]{1}, new int[]{-1}, @@ -2617,7 +2655,7 @@ void testLargeRepeatedField() throws Exception { Byte[] row = box(baos.toByteArray()); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeRaw( input.getColumn(0), new int[]{1}, new int[]{-1}, @@ -2651,7 +2689,7 @@ void testMixedPackedUnpacked() { box(tag(1, WT_LEN)), box(encodeVarint(packedContent.length)), box(packedContent)); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeRaw( input.getColumn(0), new int[]{1}, new int[]{-1}, @@ -2690,7 +2728,7 @@ void testLargeFieldNumber() { box(encodeVarint(42))); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeRaw( input.getColumn(0), new int[]{maxFieldNumber}, new int[]{-1}, @@ -2765,7 +2803,7 @@ private void verifyDeepNesting(int numLevels) { } try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeRaw( input.getColumn(0), fieldNumbers, parentIndices, depthLevels, wireTypes, outputTypeIds, encodings, isRepeated, isRequired, @@ -2793,7 +2831,7 @@ void testZeroLengthNestedMessage() { box(encodeVarint(0))); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeRaw( input.getColumn(0), new int[]{1, 1}, new int[]{-1, 0}, @@ -2822,7 +2860,7 @@ void testEmptyPackedRepeated() { box(encodeVarint(0))); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - 
ColumnVector result = Protobuf.decodeToStruct( + ColumnVector result = decodeRaw( input.getColumn(0), new int[]{1}, new int[]{-1}, From 661f0853d9c6006817d4b4d8811ab172be3faaa0 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 3 Mar 2026 16:37:53 +0800 Subject: [PATCH 038/107] codex review and address Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index af6f9c6813..36220605f7 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -180,6 +180,7 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, default_string_values.emplace_back(reinterpret_cast(bytes), reinterpret_cast(bytes) + len); env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT); + env->DeleteLocalRef(byte_arr); } } @@ -197,6 +198,7 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, if (ints == nullptr) { return 0; } enum_values.emplace_back(ints, ints + len); env->ReleaseIntArrayElements(int_arr, ints, JNI_ABORT); + env->DeleteLocalRef(int_arr); } } @@ -226,9 +228,11 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, names_for_field.emplace_back(reinterpret_cast(bytes), reinterpret_cast(bytes) + len); env->ReleaseByteArrayElements(name_bytes, bytes, JNI_ABORT); + env->DeleteLocalRef(name_bytes); } } enum_name_values.push_back(std::move(names_for_field)); + env->DeleteLocalRef(names_arr); } } From 8e09a47230a5b87ffce12ca6dd586972feb4e813 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 3 Mar 2026 18:20:29 +0800 Subject: [PATCH 039/107] gemini review and address Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 12 +++++++++++- src/main/cpp/src/protobuf_builders.cu | 3 +-- src/main/cpp/src/protobuf_common.cuh | 4 ++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 
6293e5d299..357e240181 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -459,8 +459,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_occurrences.data(), d_error.data()); + check_error_and_throw(); + // Build the appropriate column type based on element type - // For now, support scalar repeated fields auto child_type_id = static_cast(h_device_schema[schema_idx].output_type_id); // The output_type in schema is the LIST type, but we need element type @@ -477,6 +478,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_occurrences, total_count, num_rows, + d_error, stream, mr); break; @@ -491,6 +493,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_occurrences, total_count, num_rows, + d_error, stream, mr); break; @@ -505,6 +508,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_occurrences, total_count, num_rows, + d_error, stream, mr); break; @@ -519,6 +523,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_occurrences, total_count, num_rows, + d_error, stream, mr); break; @@ -533,6 +538,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_occurrences, total_count, num_rows, + d_error, stream, mr); break; @@ -547,6 +553,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_occurrences, total_count, num_rows, + d_error, stream, mr); break; @@ -561,6 +568,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_occurrences, total_count, num_rows, + d_error, stream, mr); break; @@ -596,6 +604,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& total_count, num_rows, false, + d_error, stream, mr); } @@ -612,6 +621,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& total_count, num_rows, true, + d_error, stream, mr); break; diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 
50dc9e6617..5e8c6fc9da 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -479,6 +479,7 @@ std::unique_ptr build_repeated_string_column( int total_count, int num_rows, bool is_bytes, + rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { @@ -541,8 +542,6 @@ std::unique_ptr build_repeated_string_column( str_lengths.begin(), str_lengths.end(), stream, mr); rmm::device_uvector chars(total_chars, stream, mr); - rmm::device_uvector d_error(1, stream, mr); - CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); if (total_chars > 0) { RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; copy_varlen_data_kernel diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index eb99e26e98..ed94df358b 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -1021,6 +1021,7 @@ std::unique_ptr build_repeated_string_column( int total_count, int num_rows, bool is_bytes, + rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr); @@ -1339,6 +1340,7 @@ inline std::unique_ptr build_repeated_scalar_column( rmm::device_uvector& d_occurrences, int total_count, int num_rows, + rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { @@ -1391,8 +1393,6 @@ inline std::unique_ptr build_repeated_scalar_column( stream.value())); rmm::device_uvector values(total_count, stream, mr); - rmm::device_uvector d_error(1, stream, mr); - CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); auto const threads = THREADS_PER_BLOCK; auto const blocks = (total_count + threads - 1) / threads; From 3789cfc5db5ba67121404fa626b919decd5d173a Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 4 Mar 2026 10:56:56 +0800 Subject: [PATCH 040/107] adapt new api Signed-off-by: Haoyang Li --- 
src/main/cpp/src/protobuf.cu | 22 ++----- src/main/cpp/src/protobuf_builders.cu | 93 +++++++-------------------- src/main/cpp/src/protobuf_common.cuh | 28 ++------ 3 files changed, 36 insertions(+), 107 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 357e240181..953083a676 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -62,13 +62,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& cudf::data_type{cudf::type_id::INT32}, 1, offsets.release(), rmm::device_buffer{}, 0); auto empty_struct = make_empty_struct_column_with_schema( schema, schema_output_types, i, num_fields, stream, mr); - empty_children.push_back(cudf::make_lists_column(0, - std::move(offsets_col), - std::move(empty_struct), - 0, - rmm::device_buffer{}, - stream, - mr)); + empty_children.push_back(cudf::make_lists_column( + 0, std::move(offsets_col), std::move(empty_struct), 0, rmm::device_buffer{})); } else if (field_type.id() == cudf::type_id::STRUCT && !schema[i].is_repeated) { // Non-repeated nested message field empty_children.push_back(make_empty_struct_column_with_schema( @@ -694,17 +689,10 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& std::move(offsets_col), std::move(child_col), input_null_count, - std::move(null_mask), - stream, - mr); + std::move(null_mask)); } else { - column_map[schema_idx] = cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - 0, - rmm::device_buffer{}, - stream, - mr); + column_map[schema_idx] = cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); } } } diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 5e8c6fc9da..3b1c43d369 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -99,13 +99,8 @@ inline std::unique_ptr build_repeated_msg_child_varlen_column( 
rmm::device_buffer(d_data.data(), total_data, stream, mr), rmm::device_buffer{}, 0); - return cudf::make_lists_column(total_count, - std::move(offsets_col), - std::move(bytes_child), - null_count, - std::move(mask), - stream, - mr); + return cudf::make_lists_column( + total_count, std::move(offsets_col), std::move(bytes_child), null_count, std::move(mask)); } return cudf::make_strings_column( @@ -190,7 +185,7 @@ std::unique_ptr make_empty_column_safe(cudf::data_type dtype, auto child_col = std::make_unique( cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); return cudf::make_lists_column( - 0, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + 0, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); } case cudf::type_id::STRUCT: { // Create empty struct column with no children @@ -222,13 +217,8 @@ std::unique_ptr make_null_list_column_with_child( rmm::device_buffer{}, 0); auto null_mask = cudf::create_null_mask(num_rows, cudf::mask_state::ALL_NULL, stream, mr); - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - num_rows, - std::move(null_mask), - stream, - mr); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), num_rows, std::move(null_mask)); } /** @@ -250,7 +240,7 @@ std::unique_ptr make_empty_list_column(std::unique_ptr build_enum_string_column( @@ -460,12 +450,10 @@ std::unique_ptr build_repeated_enum_string_column( std::move(list_offs_col), std::move(child_col), input_null_count, - std::move(null_mask), - stream, - mr); + std::move(null_mask)); } return cudf::make_lists_column( - num_rows, std::move(list_offs_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + num_rows, std::move(list_offs_col), std::move(child_col), 0, rmm::device_buffer{}); } std::unique_ptr build_repeated_string_column( @@ -505,18 +493,11 @@ std::unique_ptr build_repeated_string_column( 
std::move(offsets_col), std::move(child_col), input_null_count, - std::move(null_mask), - stream, - mr); + std::move(null_mask)); } else { // No input nulls, all rows get empty arrays [] - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - 0, - rmm::device_buffer{}, - stream, - mr); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); } } @@ -561,13 +542,8 @@ std::unique_ptr build_repeated_string_column( rmm::device_buffer(chars.data(), total_chars, stream, mr), rmm::device_buffer{}, 0); - child_col = cudf::make_lists_column(total_count, - std::move(str_offsets_col), - std::move(bytes_child), - 0, - rmm::device_buffer{}, - stream, - mr); + child_col = cudf::make_lists_column( + total_count, std::move(str_offsets_col), std::move(bytes_child), 0, rmm::device_buffer{}); } else { child_col = cudf::make_strings_column( total_count, std::move(str_offsets_col), chars.release(), 0, rmm::device_buffer{}); @@ -588,13 +564,11 @@ std::unique_ptr build_repeated_string_column( std::move(offsets_col), std::move(child_col), input_null_count, - std::move(null_mask), - stream, - mr); + std::move(null_mask)); } return cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); } // Forward declaration -- build_nested_struct_column is defined after build_repeated_struct_column @@ -710,17 +684,10 @@ std::unique_ptr build_repeated_struct_column( std::move(offsets_col), std::move(empty_struct), input_null_count, - std::move(null_mask), - stream, - mr); + std::move(null_mask)); } else { - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(empty_struct), - 0, - rmm::device_buffer{}, - stream, - mr); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(empty_struct), 0, 
rmm::device_buffer{}); } } @@ -982,13 +949,11 @@ std::unique_ptr build_repeated_struct_column( std::move(offsets_col), std::move(struct_col), input_null_count, - std::move(null_mask), - stream, - mr); + std::move(null_mask)); } return cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(struct_col), 0, rmm::device_buffer{}, stream, mr); + num_rows, std::move(offsets_col), std::move(struct_col), 0, rmm::device_buffer{}); } std::unique_ptr build_nested_struct_column( @@ -1387,13 +1352,8 @@ std::unique_ptr build_repeated_child_list_column( } else { child_col = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); } - return cudf::make_lists_column(num_parent_rows, - std::move(list_offsets_col), - std::move(child_col), - 0, - rmm::device_buffer{}, - stream, - mr); + return cudf::make_lists_column( + num_parent_rows, std::move(list_offsets_col), std::move(child_col), 0, rmm::device_buffer{}); } rmm::device_uvector list_offs(num_parent_rows + 1, stream, mr); @@ -1514,13 +1474,8 @@ std::unique_ptr build_repeated_child_list_column( list_offs.release(), rmm::device_buffer{}, 0); - return cudf::make_lists_column(num_parent_rows, - std::move(list_offs_col), - std::move(child_values), - 0, - rmm::device_buffer{}, - stream, - mr); + return cudf::make_lists_column( + num_parent_rows, std::move(list_offs_col), std::move(child_values), 0, rmm::device_buffer{}); } } // namespace spark_rapids_jni::protobuf_detail diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index ed94df358b..bc9b1dc86f 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -1152,13 +1152,8 @@ inline std::unique_ptr extract_and_build_string_or_bytes_column( rmm::device_buffer(chars.data(), total_size, stream, mr), rmm::device_buffer{}, 0); - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(bytes_child), - null_count, - std::move(mask), - stream, - mr); + return 
cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(bytes_child), null_count, std::move(mask)); } return cudf::make_strings_column( @@ -1367,18 +1362,11 @@ inline std::unique_ptr build_repeated_scalar_column( std::move(offsets_col), std::move(child_col), input_null_count, - std::move(null_mask), - stream, - mr); + std::move(null_mask)); } else { // No input nulls, all rows get empty arrays [] - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - 0, - rmm::device_buffer{}, - stream, - mr); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); } } @@ -1447,13 +1435,11 @@ inline std::unique_ptr build_repeated_scalar_column( std::move(offsets_col), std::move(child_col), input_null_count, - std::move(null_mask), - stream, - mr); + std::move(null_mask)); } return cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}, stream, mr); + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); } } // namespace spark_rapids_jni::protobuf_detail From 33ddacd723052f250cfc99d078ed423b51b3559e Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 4 Mar 2026 14:16:32 +0800 Subject: [PATCH 041/107] add micro benchmark Signed-off-by: Haoyang Li --- src/main/cpp/benchmarks/CMakeLists.txt | 3 + src/main/cpp/benchmarks/protobuf_decode.cu | 546 +++++++++++++++++++++ 2 files changed, 549 insertions(+) create mode 100644 src/main/cpp/benchmarks/protobuf_decode.cu diff --git a/src/main/cpp/benchmarks/CMakeLists.txt b/src/main/cpp/benchmarks/CMakeLists.txt index 7641552622..00b39be8e1 100644 --- a/src/main/cpp/benchmarks/CMakeLists.txt +++ b/src/main/cpp/benchmarks/CMakeLists.txt @@ -86,3 +86,6 @@ ConfigureBench(GET_JSON_OBJECT_BENCH ConfigureBench(PARSE_URI_BENCH parse_uri.cpp) + +ConfigureBench(PROTOBUF_DECODE_BENCH + protobuf_decode.cu) diff --git a/src/main/cpp/benchmarks/protobuf_decode.cu 
b/src/main/cpp/benchmarks/protobuf_decode.cu new file mode 100644 index 0000000000..05cb5a8833 --- /dev/null +++ b/src/main/cpp/benchmarks/protobuf_decode.cu @@ -0,0 +1,546 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace { + +// --------------------------------------------------------------------------- +// Protobuf wire-format encoding helpers (host side, for generating test data) +// --------------------------------------------------------------------------- + +void encode_varint(std::vector& buf, uint64_t value) +{ + while (value > 0x7F) { + buf.push_back(static_cast((value & 0x7F) | 0x80)); + value >>= 7; + } + buf.push_back(static_cast(value)); +} + +void encode_tag(std::vector& buf, int field_number, int wire_type) +{ + encode_varint(buf, static_cast(field_number << 3 | wire_type)); +} + +void encode_varint_field(std::vector& buf, int field_number, int64_t value) +{ + encode_tag(buf, field_number, /*WT_VARINT=*/0); + encode_varint(buf, static_cast(value)); +} + +void encode_fixed32_field(std::vector& buf, int field_number, float value) +{ + encode_tag(buf, field_number, /*WT_32BIT=*/5); + uint32_t bits; + std::memcpy(&bits, &value, sizeof(bits)); + for (int i = 0; i < 4; i++) { + buf.push_back(static_cast(bits & 0xFF)); + bits >>= 8; + } +} + +void 
encode_fixed64_field(std::vector& buf, int field_number, double value) +{ + encode_tag(buf, field_number, /*WT_64BIT=*/1); + uint64_t bits; + std::memcpy(&bits, &value, sizeof(bits)); + for (int i = 0; i < 8; i++) { + buf.push_back(static_cast(bits & 0xFF)); + bits >>= 8; + } +} + +void encode_len_field(std::vector& buf, int field_number, void const* data, size_t len) +{ + encode_tag(buf, field_number, /*WT_LEN=*/2); + encode_varint(buf, len); + auto const* p = static_cast(data); + buf.insert(buf.end(), p, p + len); +} + +void encode_string_field(std::vector& buf, int field_number, std::string const& s) +{ + encode_len_field(buf, field_number, s.data(), s.size()); +} + +// Encode a nested message: write its content into a temporary buffer, then emit as WT_LEN. +template +void encode_nested_message(std::vector& buf, int field_number, Fn&& content_fn) +{ + std::vector inner; + content_fn(inner); + encode_len_field(buf, field_number, inner.data(), inner.size()); +} + +// Encode a packed repeated int32 field. 
+void encode_packed_repeated_int32(std::vector& buf, + int field_number, + std::vector const& values) +{ + std::vector packed; + for (auto v : values) { + encode_varint(packed, static_cast(static_cast(v))); + } + encode_len_field(buf, field_number, packed.data(), packed.size()); +} + +// --------------------------------------------------------------------------- +// Build a cuDF LIST column from host message buffers +// --------------------------------------------------------------------------- + +std::unique_ptr make_binary_column(std::vector> const& messages) +{ + auto stream = cudf::get_default_stream(); + auto mr = rmm::mr::get_current_device_resource(); + + std::vector h_offsets(messages.size() + 1); + h_offsets[0] = 0; + for (size_t i = 0; i < messages.size(); i++) { + h_offsets[i + 1] = h_offsets[i] + static_cast(messages[i].size()); + } + int32_t total_bytes = h_offsets.back(); + + std::vector h_data; + h_data.reserve(total_bytes); + for (auto const& m : messages) { + h_data.insert(h_data.end(), m.begin(), m.end()); + } + + rmm::device_buffer d_data(h_data.data(), h_data.size(), stream, mr); + rmm::device_buffer d_offsets(h_offsets.data(), h_offsets.size() * sizeof(int32_t), stream, mr); + stream.synchronize(); + + auto child_col = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, + total_bytes, + std::move(d_data), + rmm::device_buffer{}, + 0); + auto offsets_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, + static_cast(h_offsets.size()), + std::move(d_offsets), + rmm::device_buffer{}, + 0); + + return cudf::make_lists_column(static_cast(messages.size()), + std::move(offsets_col), + std::move(child_col), + 0, + rmm::device_buffer{}); +} + +// --------------------------------------------------------------------------- +// Schema + message generators for different benchmark scenarios +// --------------------------------------------------------------------------- + +using nfd = spark_rapids_jni::nested_field_descriptor; + +// Case 1: 
Flat scalars only — many top-level scalar fields. +// message FlatMessage { +// int32 f1 = 1; +// int64 f2 = 2; +// ... +// float f_k = k; (cycling through int32, int64, float, double, bool) +// string s_k+1 = k+1; (a few string fields) +// } +struct FlatScalarCase { + int num_int_fields; + int num_string_fields; + + spark_rapids_jni::ProtobufDecodeContext build_context() const + { + spark_rapids_jni::ProtobufDecodeContext ctx; + ctx.fail_on_errors = true; + + // type_id cycle for integer-like fields + cudf::type_id int_types[] = {cudf::type_id::INT32, + cudf::type_id::INT64, + cudf::type_id::FLOAT32, + cudf::type_id::FLOAT64, + cudf::type_id::BOOL8}; + int wt_for_type[] = {0 /*WT_VARINT*/, 0, 5 /*WT_32BIT*/, 1 /*WT_64BIT*/, 0}; + + int fn = 1; + for (int i = 0; i < num_int_fields; i++, fn++) { + int ti = i % 5; + auto ty = int_types[ti]; + int wt = wt_for_type[ti]; + int enc = spark_rapids_jni::ENC_DEFAULT; + if (ty == cudf::type_id::FLOAT32) enc = spark_rapids_jni::ENC_FIXED; + if (ty == cudf::type_id::FLOAT64) enc = spark_rapids_jni::ENC_FIXED; + ctx.schema.push_back({fn, -1, 0, wt, ty, enc, false, false, false}); + } + for (int i = 0; i < num_string_fields; i++, fn++) { + ctx.schema.push_back( + {fn, -1, 0, 2 /*WT_LEN*/, cudf::type_id::STRING, 0, false, false, false}); + } + + size_t n = ctx.schema.size(); + for (auto const& f : ctx.schema) { + ctx.schema_output_types.emplace_back(f.output_type); + } + ctx.default_ints.resize(n, 0); + ctx.default_floats.resize(n, 0.0); + ctx.default_bools.resize(n, false); + ctx.default_strings.resize(n); + ctx.enum_valid_values.resize(n); + ctx.enum_names.resize(n); + return ctx; + } + + std::vector> generate_messages(int num_rows, std::mt19937& rng) const + { + std::vector> messages(num_rows); + std::uniform_int_distribution int_dist(0, 100000); + std::uniform_int_distribution str_len_dist(5, 50); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"; + + for (int r = 0; r < num_rows; r++) { + auto& buf = 
messages[r]; + int fn = 1; + for (int i = 0; i < num_int_fields; i++, fn++) { + int ti = i % 5; + switch (ti) { + case 0: encode_varint_field(buf, fn, int_dist(rng)); break; + case 1: encode_varint_field(buf, fn, int_dist(rng)); break; + case 2: encode_fixed32_field(buf, fn, static_cast(int_dist(rng))); break; + case 3: encode_fixed64_field(buf, fn, static_cast(int_dist(rng))); break; + case 4: encode_varint_field(buf, fn, rng() % 2); break; + } + } + for (int i = 0; i < num_string_fields; i++, fn++) { + int len = str_len_dist(rng); + std::string s(len, ' '); + for (int c = 0; c < len; c++) { + s[c] = alphabet[rng() % alphabet.size()]; + } + encode_string_field(buf, fn, s); + } + } + return messages; + } +}; + +// Case 2: Nested message — a top-level message with a nested struct child. +// message OuterMessage { +// int32 id = 1; +// string name = 2; +// InnerMessage inner = 3; +// } +// message InnerMessage { +// int32 x = 1; +// int64 y = 2; +// string data = 3; +// ... (num_inner_fields fields) +// } +struct NestedMessageCase { + int num_inner_fields; // scalar fields inside InnerMessage + + spark_rapids_jni::ProtobufDecodeContext build_context() const + { + spark_rapids_jni::ProtobufDecodeContext ctx; + ctx.fail_on_errors = true; + + // idx 0: id (int32, top-level) + ctx.schema.push_back({1, -1, 0, 0, cudf::type_id::INT32, 0, false, false, false}); + // idx 1: name (string, top-level) + ctx.schema.push_back({2, -1, 0, 2, cudf::type_id::STRING, 0, false, false, false}); + // idx 2: inner (STRUCT, top-level) + ctx.schema.push_back({3, -1, 0, 2, cudf::type_id::STRUCT, 0, false, false, false}); + + // Inner message children (parent_idx=2, depth=1) + cudf::type_id inner_types[] = { + cudf::type_id::INT32, cudf::type_id::INT64, cudf::type_id::STRING}; + int inner_wt[] = {0, 0, 2}; + + for (int i = 0; i < num_inner_fields; i++) { + int ti = i % 3; + ctx.schema.push_back( + {i + 1, 2, 1, inner_wt[ti], inner_types[ti], 0, false, false, false}); + } + + size_t n = 
ctx.schema.size(); + for (auto const& f : ctx.schema) { + ctx.schema_output_types.emplace_back(f.output_type); + } + ctx.default_ints.resize(n, 0); + ctx.default_floats.resize(n, 0.0); + ctx.default_bools.resize(n, false); + ctx.default_strings.resize(n); + ctx.enum_valid_values.resize(n); + ctx.enum_names.resize(n); + return ctx; + } + + std::vector> generate_messages(int num_rows, std::mt19937& rng) const + { + std::vector> messages(num_rows); + std::uniform_int_distribution int_dist(0, 100000); + std::uniform_int_distribution str_len_dist(5, 30); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + + auto random_string = [&](int len) { + std::string s(len, ' '); + for (int c = 0; c < len; c++) s[c] = alphabet[rng() % alphabet.size()]; + return s; + }; + + for (int r = 0; r < num_rows; r++) { + auto& buf = messages[r]; + encode_varint_field(buf, 1, int_dist(rng)); + encode_string_field(buf, 2, random_string(str_len_dist(rng))); + + encode_nested_message(buf, 3, [&](std::vector& inner) { + for (int i = 0; i < num_inner_fields; i++) { + int ti = i % 3; + switch (ti) { + case 0: encode_varint_field(inner, i + 1, int_dist(rng)); break; + case 1: encode_varint_field(inner, i + 1, int_dist(rng)); break; + case 2: encode_string_field(inner, i + 1, random_string(str_len_dist(rng))); break; + } + } + }); + } + return messages; + } +}; + +// Case 3: Repeated fields — top-level repeated scalars and a repeated nested message. 
+// message RepeatedMessage { +// int32 id = 1; +// repeated int32 tags = 2; +// repeated string labels = 3; +// repeated Item items = 4; +// } +// message Item { +// int32 item_id = 1; +// string item_name = 2; +// int64 value = 3; +// } +struct RepeatedFieldCase { + int avg_tags_per_row; + int avg_labels_per_row; + int avg_items_per_row; + + spark_rapids_jni::ProtobufDecodeContext build_context() const + { + spark_rapids_jni::ProtobufDecodeContext ctx; + ctx.fail_on_errors = true; + + // idx 0: id (int32, scalar) + ctx.schema.push_back({1, -1, 0, 0, cudf::type_id::INT32, 0, false, false, false}); + // idx 1: tags (repeated int32, packed) + ctx.schema.push_back({2, -1, 0, 0, cudf::type_id::INT32, 0, true, false, false}); + // idx 2: labels (repeated string) + ctx.schema.push_back({3, -1, 0, 2, cudf::type_id::STRING, 0, true, false, false}); + // idx 3: items (repeated STRUCT) + ctx.schema.push_back({4, -1, 0, 2, cudf::type_id::STRUCT, 0, true, false, false}); + // idx 4: Item.item_id (int32, child of idx 3) + ctx.schema.push_back({1, 3, 1, 0, cudf::type_id::INT32, 0, false, false, false}); + // idx 5: Item.item_name (string, child of idx 3) + ctx.schema.push_back({2, 3, 1, 2, cudf::type_id::STRING, 0, false, false, false}); + // idx 6: Item.value (int64, child of idx 3) + ctx.schema.push_back({3, 3, 1, 0, cudf::type_id::INT64, 0, false, false, false}); + + size_t n = ctx.schema.size(); + for (auto const& f : ctx.schema) { + ctx.schema_output_types.emplace_back(f.output_type); + } + ctx.default_ints.resize(n, 0); + ctx.default_floats.resize(n, 0.0); + ctx.default_bools.resize(n, false); + ctx.default_strings.resize(n); + ctx.enum_valid_values.resize(n); + ctx.enum_names.resize(n); + return ctx; + } + + std::vector> generate_messages(int num_rows, std::mt19937& rng) const + { + std::vector> messages(num_rows); + std::uniform_int_distribution int_dist(0, 100000); + std::uniform_int_distribution str_len_dist(3, 20); + std::string alphabet = 
"abcdefghijklmnopqrstuvwxyz"; + + auto random_string = [&](int len) { + std::string s(len, ' '); + for (int c = 0; c < len; c++) s[c] = alphabet[rng() % alphabet.size()]; + return s; + }; + + // Vary count per row around the average (±50%) + auto vary = [&](int avg) -> int { + int lo = std::max(0, avg / 2); + int hi = avg + avg / 2; + return std::uniform_int_distribution(lo, std::max(lo, hi))(rng); + }; + + for (int r = 0; r < num_rows; r++) { + auto& buf = messages[r]; + + // id + encode_varint_field(buf, 1, int_dist(rng)); + + // tags (packed repeated int32) + { + int n = vary(avg_tags_per_row); + std::vector tags(n); + for (auto& t : tags) t = int_dist(rng); + if (n > 0) encode_packed_repeated_int32(buf, 2, tags); + } + + // labels (unpacked repeated string) + { + int n = vary(avg_labels_per_row); + for (int i = 0; i < n; i++) { + encode_string_field(buf, 3, random_string(str_len_dist(rng))); + } + } + + // items (repeated nested message) + { + int n = vary(avg_items_per_row); + for (int i = 0; i < n; i++) { + encode_nested_message(buf, 4, [&](std::vector& inner) { + encode_varint_field(inner, 1, int_dist(rng)); + encode_string_field(inner, 2, random_string(str_len_dist(rng))); + encode_varint_field(inner, 3, int_dist(rng)); + }); + } + } + } + return messages; + } +}; + +} // anonymous namespace + +// =========================================================================== +// Benchmark 1: Flat scalars — measures per-field extraction overhead +// =========================================================================== +static void BM_protobuf_flat_scalars(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const num_fields = static_cast(state.get_int64("num_fields")); + int const num_str = std::max(1, num_fields / 10); + int const num_int = num_fields - num_str; + + FlatScalarCase flat_case{num_int, num_str}; + auto ctx = flat_case.build_context(); + + std::mt19937 rng(42); + auto messages = 
flat_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); + + size_t total_bytes = 0; + for (auto const& m : messages) total_bytes += m.size(); + + auto stream = cudf::get_default_stream(); + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + auto result = spark_rapids_jni::decode_protobuf_to_struct(binary_col->view(), ctx, stream); + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_flat_scalars) + .set_name("Protobuf Flat Scalars") + .add_int64_axis("num_rows", {10'000, 100'000, 500'000}) + .add_int64_axis("num_fields", {10, 50, 200}); + +// =========================================================================== +// Benchmark 2: Nested messages — measures nested struct build overhead +// =========================================================================== +static void BM_protobuf_nested(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const inner_fields = static_cast(state.get_int64("inner_fields")); + + NestedMessageCase nested_case{inner_fields}; + auto ctx = nested_case.build_context(); + + std::mt19937 rng(42); + auto messages = nested_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); + + size_t total_bytes = 0; + for (auto const& m : messages) total_bytes += m.size(); + + auto stream = cudf::get_default_stream(); + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + auto result = spark_rapids_jni::decode_protobuf_to_struct(binary_col->view(), ctx, stream); + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_nested) + .set_name("Protobuf Nested Message") + .add_int64_axis("num_rows", {10'000, 100'000, 
500'000}) + .add_int64_axis("inner_fields", {5, 20, 100}); + +// =========================================================================== +// Benchmark 3: Repeated fields — measures repeated field pipeline overhead +// =========================================================================== +static void BM_protobuf_repeated(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const avg_items = static_cast(state.get_int64("avg_items")); + + RepeatedFieldCase rep_case{/*avg_tags=*/5, /*avg_labels=*/3, /*avg_items=*/avg_items}; + auto ctx = rep_case.build_context(); + + std::mt19937 rng(42); + auto messages = rep_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); + + size_t total_bytes = 0; + for (auto const& m : messages) total_bytes += m.size(); + + auto stream = cudf::get_default_stream(); + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + auto result = spark_rapids_jni::decode_protobuf_to_struct(binary_col->view(), ctx, stream); + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_repeated) + .set_name("Protobuf Repeated Fields") + .add_int64_axis("num_rows", {10'000, 100'000}) + .add_int64_axis("avg_items", {1, 5, 20}); From c56ed5b4b423863e866cc71175a86722250cc61e Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 4 Mar 2026 15:37:30 +0800 Subject: [PATCH 042/107] merge reduce+scan, defer error check, combined occ scan Signed-off-by: Haoyang Li --- perf-results/RESULTS.md | 222 ++++++++++++++++++ perf-results/run_00_baseline.txt | 115 +++++++++ .../run_00_baseline_many_repeated.txt | 47 ++++ .../run_01_p0_batch_repeated_sync.txt | 118 ++++++++++ perf-results/run_01_p0_many_repeated.txt | 48 ++++ perf-results/run_01_p0_merge_reduce_scan.txt | 117 +++++++++ perf-results/run_02_p3_defer_error_check.txt | 139 
+++++++++++ .../run_03_p1_combined_occurrence_scan.txt | 143 +++++++++++ src/main/cpp/benchmarks/protobuf_decode.cu | 201 +++++++++++++--- src/main/cpp/src/protobuf.cu | 128 ++++++---- src/main/cpp/src/protobuf_common.cuh | 18 ++ src/main/cpp/src/protobuf_kernels.cu | 73 ++++++ 12 files changed, 1284 insertions(+), 85 deletions(-) create mode 100644 perf-results/RESULTS.md create mode 100644 perf-results/run_00_baseline.txt create mode 100644 perf-results/run_00_baseline_many_repeated.txt create mode 100644 perf-results/run_01_p0_batch_repeated_sync.txt create mode 100644 perf-results/run_01_p0_many_repeated.txt create mode 100644 perf-results/run_01_p0_merge_reduce_scan.txt create mode 100644 perf-results/run_02_p3_defer_error_check.txt create mode 100644 perf-results/run_03_p1_combined_occurrence_scan.txt diff --git a/perf-results/RESULTS.md b/perf-results/RESULTS.md new file mode 100644 index 0000000000..b3a3864aa3 --- /dev/null +++ b/perf-results/RESULTS.md @@ -0,0 +1,222 @@ +# Protobuf GPU Decoder Performance Optimization Log + +**GPU**: NVIDIA RTX 5880 Ada Generation (110 SMs, 48GB) +**Benchmark**: `PROTOBUF_DECODE_BENCH` (3 cases: Flat Scalars, Nested Message, Repeated Fields) + +--- + +## 00. 
Baseline + +**日期**: 2026-03-04 +**改动摘要**: 无改动,记录初始性能基准 +**原始输出**: `run_00_baseline.txt` + +### Flat Scalars + +| num_rows | num_fields | GPU Time (ms) | +|----------|------------|---------------| +| 10,000 | 10 | 0.693 | +| 100,000 | 10 | 1.502 | +| 500,000 | 10 | 4.385 | +| 10,000 | 50 | 2.549 | +| 100,000 | 50 | 7.030 | +| 500,000 | 50 | 26.646 | +| 10,000 | 200 | 10.033 | +| 100,000 | 200 | 28.587 | +| 500,000 | 200 | 115.022 | + +### Nested Message + +| num_rows | inner_fields | GPU Time (ms) | +|----------|--------------|---------------| +| 10,000 | 5 | 0.699 | +| 100,000 | 5 | 1.551 | +| 500,000 | 5 | 4.787 | +| 10,000 | 20 | 2.445 | +| 100,000 | 20 | 4.794 | +| 500,000 | 20 | 13.995 | +| 10,000 | 100 | 10.549 | +| 100,000 | 100 | 23.539 | +| 500,000 | 100 | 68.828 | + +### Repeated Fields + +| num_rows | avg_items | GPU Time (ms) | +|----------|-----------|---------------| +| 10,000 | 1 | 0.966 | +| 100,000 | 1 | 2.362 | +| 10,000 | 5 | 1.292 | +| 100,000 | 5 | 3.999 | +| 10,000 | 20 | 2.221 | +| 100,000 | 20 | 7.696 | + +--- + +## 01. 
P0: Merge thrust::reduce + exclusive_scan into inclusive_scan + +**日期**: 2026-03-04 +**改动摘要**: 替换每个 repeated 字段的 `thrust::reduce` (implicit sync) + `thrust::exclusive_scan` + H2D copy 为单次 `thrust::inclusive_scan` + D2H copy,消除 CUB reduce 内部开销 +**改动文件**: protobuf.cu +**原始输出**: `run_01_p0_merge_reduce_scan.txt` + +### Flat Scalars + +| num_rows | num_fields | Before (ms) | After (ms) | Speedup | +|----------|------------|-------------|------------|---------| +| 10,000 | 10 | 0.693 | 0.727 | - | +| 100,000 | 10 | 1.502 | 1.537 | - | +| 500,000 | 10 | 4.385 | 4.444 | - | +| 10,000 | 50 | 2.549 | 2.727 | - | +| 100,000 | 50 | 7.030 | 7.246 | - | +| 500,000 | 50 | 26.646 | 26.963 | - | +| 10,000 | 200 | 10.033 | 10.912 | - | +| 100,000 | 200 | 28.587 | 29.367 | - | +| 500,000 | 200 | 115.022 | 116.532 | - | + +### Nested Message + +| num_rows | inner_fields | Before (ms) | After (ms) | Speedup | +|----------|--------------|-------------|------------|---------| +| 10,000 | 5 | 0.699 | 0.750 | - | +| 100,000 | 5 | 1.551 | 1.639 | - | +| 500,000 | 5 | 4.787 | 4.810 | - | +| 10,000 | 20 | 2.445 | 2.630 | - | +| 100,000 | 20 | 4.794 | 5.106 | - | +| 500,000 | 20 | 13.995 | 14.065 | - | +| 10,000 | 100 | 10.549 | 11.389 | - | +| 100,000 | 100 | 23.539 | 24.114 | - | +| 500,000 | 100 | 68.828 | 69.461 | - | + +### Repeated Fields + +| num_rows | avg_items | Before (ms) | After (ms) | Speedup | +|----------|-----------|-------------|------------|---------| +| 10,000 | 1 | 0.966 | 0.945 | 1.02x | +| 100,000 | 1 | 2.362 | 2.309 | 1.02x | +| 10,000 | 5 | 1.292 | 1.271 | 1.02x | +| 100,000 | 5 | 3.999 | 3.967 | 1.01x | +| 10,000 | 20 | 2.221 | 2.220 | 1.00x | +| 100,000 | 20 | 7.696 | 7.729 | 1.00x | + +### Many Repeated Fields (新增 case,更接近客户场景) + +| num_rows | num_rep_fields | Before (ms) | After (ms) | Speedup | +|----------|----------------|-------------|------------|---------| +| 10,000 | 10 | 1.891 | 1.748 | **1.08x** | +| 100,000 | 10 | 5.595 | 5.443 | 1.03x | +| 10,000 | 30 | 
5.937 | 5.375 | **1.10x** | +| 100,000 | 30 | 18.947 | 18.527 | 1.02x | +| 10,000 | 50 | 9.851 | 9.056 | **1.09x** | +| 100,000 | 50 | 36.678 | 35.970 | 1.02x | + +**结论**: Flat/Nested 无变化(符合预期)。Repeated 3 字段 case 微小提升 (~1-2%)。**Many Repeated Fields case 在小行数下有 8-10% 提升**,大行数下 2-3% 提升——因为小行数时 sync 开销占比更大。对于客户的 ~98 个 repeated 字段 + ~13.5K 行/批的场景,此优化的收益在 small batch 区间(~10%)。 + +**UT**: 95/95 全过。 + +--- + +## 02. P3: Defer check_error_and_throw to end of decode + +**日期**: 2026-03-04 +**改动摘要**: 删除 4 处中间 `check_error_and_throw()` 调用(每次含 D2H copy + stream.synchronize),只保留函数末尾的一次最终 error check。`d_error` flag 在 GPU 上只写不读,中间检查完全冗余。 +**改动文件**: protobuf.cu +**原始输出**: `run_02_p3_defer_error_check.txt` + +### Flat Scalars (对比 baseline) + +| num_rows | num_fields | Baseline (ms) | P0+P3 (ms) | Speedup | +|----------|------------|---------------|------------|---------| +| 10,000 | 10 | 0.693 | 0.693 | 1.00x | +| 100,000 | 10 | 1.502 | 1.491 | 1.01x | +| 500,000 | 10 | 4.385 | 4.212 | **1.04x** | +| 10,000 | 50 | 2.549 | 2.584 | - | +| 100,000 | 50 | 7.030 | 7.038 | 1.00x | +| 500,000 | 50 | 26.646 | 26.500 | 1.01x | +| 10,000 | 200 | 10.033 | 10.245 | - | +| 100,000 | 200 | 28.587 | 28.589 | 1.00x | +| 500,000 | 200 | 115.022 | 115.005 | 1.00x | + +### Nested Message (对比 baseline) + +| num_rows | inner_fields | Baseline (ms) | P0+P3 (ms) | Speedup | +|----------|--------------|---------------|------------|---------| +| 10,000 | 5 | 0.699 | 0.683 | **1.02x** | +| 100,000 | 5 | 1.551 | 1.518 | **1.02x** | +| 500,000 | 5 | 4.787 | 4.584 | **1.04x** | +| 10,000 | 20 | 2.445 | 2.457 | 1.00x | +| 100,000 | 20 | 4.794 | 4.740 | **1.01x** | +| 500,000 | 20 | 13.995 | 13.798 | **1.01x** | +| 10,000 | 100 | 10.549 | 10.623 | 1.00x | +| 100,000 | 100 | 23.539 | 23.489 | 1.00x | +| 500,000 | 100 | 68.828 | 69.207 | 1.00x | + +### Repeated Fields (对比 baseline) + +| num_rows | avg_items | Baseline (ms) | P0+P3 (ms) | Speedup | +|----------|-----------|---------------|------------|---------| +| 
10,000 | 1 | 0.966 | 0.839 | **1.15x** | +| 100,000 | 1 | 2.362 | 2.153 | **1.10x** | +| 10,000 | 5 | 1.292 | 1.148 | **1.13x** | +| 100,000 | 5 | 3.999 | 3.803 | **1.05x** | +| 10,000 | 20 | 2.221 | 2.080 | **1.07x** | +| 100,000 | 20 | 7.696 | 7.516 | **1.02x** | + +### Many Repeated Fields (对比 baseline) + +| num_rows | num_rep_fields | Baseline (ms) | P0+P3 (ms) | Speedup | +|----------|----------------|---------------|------------|---------| +| 10,000 | 10 | 1.891 | 1.472 | **1.28x** | +| 100,000 | 10 | 5.595 | 5.011 | **1.12x** | +| 10,000 | 30 | 5.937 | 4.676 | **1.27x** | +| 100,000 | 30 | 18.947 | 17.065 | **1.11x** | +| 10,000 | 50 | 9.851 | 8.137 | **1.21x** | +| 100,000 | 50 | 36.678 | 33.944 | **1.08x** | + +**结论**: P3 效果显著! +- **Many Repeated Fields**: 小行数 21-28% 提升,大行数 8-12% 提升 +- **Repeated Fields (3 fields)**: 7-15% 提升 +- **Flat/Nested**: Flat 几乎不变(无 repeated → 只省了 scan/required 2 次 sync),Nested 小幅提升 1-4% +- 总体来看 P0+P3 组合在 repeated-heavy schema 下带来了可观的加速 + +**UT**: 95/95 全过。 + +--- + +## 03. 
P1: Combined repeated occurrence scan kernel + +**日期**: 2026-03-04 +**改动摘要**: 新增 `scan_all_repeated_occurrences_kernel`,一次扫描消息就记录所有 repeated 字段的 occurrence(替代原来 N 个 repeated 字段各自独立做一次全消息扫描)。同时将 offset 计算全部前置并用 1 次 sync 代替 N 次。 +**改动文件**: protobuf.cu, protobuf_kernels.cu, protobuf_common.cuh +**原始输出**: `run_03_p1_combined_occurrence_scan.txt` + +### Many Repeated Fields (P1 的核心目标 case) + +| num_rows | num_rep_fields | Baseline (ms) | P0+P3 (ms) | P0+P3+P1 (ms) | vs Baseline | vs P0+P3 | +|----------|----------------|---------------|------------|----------------|-------------|----------| +| 10,000 | 10 | 1.891 | 1.472 | 1.638 | **1.15x** | 0.90x | +| 100,000 | 10 | 5.595 | 5.011 | 5.579 | 1.00x | 0.90x | +| 10,000 | 30 | 5.937 | 4.676 | 4.620 | **1.29x** | **1.01x** | +| 100,000 | 30 | 18.947 | 17.065 | 16.545 | **1.15x** | **1.03x** | +| 10,000 | 50 | 9.851 | 8.137 | 7.958 | **1.24x** | **1.02x** | +| 100,000 | 50 | 36.678 | 33.944 | 27.664 | **1.33x** | **1.23x** | + +### Repeated Fields (3 fields, 对比 baseline) + +| num_rows | avg_items | Baseline (ms) | P0+P3+P1 (ms) | Speedup | +|----------|-----------|---------------|----------------|---------| +| 10,000 | 1 | 0.966 | 0.853 | **1.13x** | +| 100,000 | 1 | 2.362 | 2.282 | **1.04x** | +| 10,000 | 5 | 1.292 | 1.177 | **1.10x** | +| 10,000 | 20 | 2.221 | 2.129 | **1.04x** | + +### Flat Scalars / Nested Message + +不受影响(符合预期,无顶层 repeated 字段)。 + +**结论**: +- **Many Repeated 100K×50: 36.68ms → 27.66ms (vs baseline 1.33x)**,这是三个优化叠加的效果 +- P1 单独贡献(vs P0+P3):100K×50 快了 **23%**(33.94→27.66),100K×30 快了 3% +- P1 的增益在字段数多 + 行数多时最显著:每多一个 repeated 字段,原来多一次全消息扫描,现在都合并为一次 +- 10K rows + 10 fields 时 P1 有微小退化(offset 前置分配的额外开销 > 合并扫描的收益) + +**UT**: 95/95 全过。 diff --git a/perf-results/run_00_baseline.txt b/perf-results/run_00_baseline.txt new file mode 100644 index 0000000000..e9109aa058 --- /dev/null +++ b/perf-results/run_00_baseline.txt @@ -0,0 +1,115 @@ +# Devices + +## [0] `NVIDIA RTX 5880 Ada Generation` +* SM Version: 890 (PTX 
Version: 860) +* Number of SMs: 110 +* SM Default Clock Rate: 2460 MHz +* Global Memory: 45660 MiB Free / 48506 MiB Total +* Global Memory Bus Peak: 960 GB/sec (384-bit DDR @10001MHz) +* Max Shared Memory: 100 KiB/SM, 48 KiB/Block +* L2 Cache Size: 98304 KiB +* Maximum Active Blocks: 24/SM +* Maximum Active Threads: 1536/SM, 1024/Block +* Available Registers: 65536/SM, 65536/Block +* ECC Enabled: No + +# Log + +``` +Run: [1/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=10] +Pass: Cold: 0.693088ms GPU, 0.696810ms CPU, 0.69s total GPU, 0.80s total wall, 992x +Run: [2/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=10] +Pass: Cold: 1.501602ms GPU, 1.505321ms CPU, 0.50s total GPU, 0.53s total wall, 333x +Run: [3/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=10] +Pass: Cold: 4.384572ms GPU, 4.388343ms CPU, 11.93s total GPU, 12.19s total wall, 2720x +Run: [4/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=50] +Pass: Cold: 2.548858ms GPU, 2.552612ms CPU, 5.91s total GPU, 6.15s total wall, 2320x +Run: [5/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=50] +Pass: Cold: 7.030410ms GPU, 7.034253ms CPU, 9.96s total GPU, 10.08s total wall, 1416x +Run: [6/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=50] +Pass: Cold: 26.646053ms GPU, 26.650230ms CPU, 0.51s total GPU, 0.51s total wall, 19x +Run: [7/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=200] +Pass: Cold: 10.032932ms GPU, 10.036624ms CPU, 0.50s total GPU, 0.51s total wall, 50x +Run: [8/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=200] +Pass: Cold: 28.587383ms GPU, 28.591340ms CPU, 0.51s total GPU, 0.52s total wall, 18x +Run: [9/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=200] +Pass: Cold: 115.021551ms GPU, 115.026983ms CPU, 1.27s total GPU, 1.27s total wall, 11x +Run: [10/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=5] +Pass: Cold: 0.698673ms GPU, 0.702294ms 
CPU, 0.50s total GPU, 0.58s total wall, 716x +Run: [11/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=5] +Pass: Cold: 1.550901ms GPU, 1.554820ms CPU, 1.39s total GPU, 1.48s total wall, 896x +Run: [12/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=5] +Pass: Cold: 4.786635ms GPU, 4.790429ms CPU, 13.17s total GPU, 13.44s total wall, 2752x +Run: [13/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=20] +Pass: Cold: 2.445035ms GPU, 2.448843ms CPU, 0.50s total GPU, 0.52s total wall, 205x +Run: [14/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=20] +Pass: Cold: 4.794299ms GPU, 4.798064ms CPU, 0.50s total GPU, 0.51s total wall, 105x +Run: [15/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=20] +Pass: Cold: 13.994625ms GPU, 13.998604ms CPU, 6.38s total GPU, 6.41s total wall, 456x +Run: [16/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=100] +Pass: Cold: 10.549263ms GPU, 10.553169ms CPU, 11.31s total GPU, 11.42s total wall, 1072x +Run: [17/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=100] +Warn: Current measurement timed out (15.00s) while over noise threshold (7.07% > 0.50%) +Pass: Cold: 23.539355ms GPU, 23.543439ms CPU, 14.95s total GPU, 15.00s total wall, 635x +Run: [18/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=100] +Pass: Cold: 68.827659ms GPU, 68.833052ms CPU, 0.76s total GPU, 0.76s total wall, 11x +Run: [19/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=1] +Pass: Cold: 0.965625ms GPU, 0.969249ms CPU, 1.15s total GPU, 1.29s total wall, 1196x +Run: [20/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=1] +Pass: Cold: 2.362006ms GPU, 2.365883ms CPU, 2.76s total GPU, 2.88s total wall, 1168x +Run: [21/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=5] +Pass: Cold: 1.291744ms GPU, 1.295457ms CPU, 1.47s total GPU, 1.59s total wall, 1141x +Run: [22/24] Protobuf Repeated 
Fields [Device=0 num_rows=100000 avg_items=5] +Pass: Cold: 3.999376ms GPU, 4.003021ms CPU, 8.96s total GPU, 9.18s total wall, 2240x +Run: [23/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=20] +Pass: Cold: 2.220686ms GPU, 2.224453ms CPU, 0.50s total GPU, 0.52s total wall, 226x +Run: [24/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=20] +Pass: Cold: 7.696091ms GPU, 7.699906ms CPU, 0.50s total GPU, 0.51s total wall, 65x +``` + +# Benchmark Results + +## Protobuf Flat Scalars + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | num_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|------------|---------|------------|-------|------------|-------|--------| +| 10000 | 10 | 992x | 696.810 us | 0.51% | 693.088 us | 0.51% | 10000 | +| 100000 | 10 | 333x | 1.505 ms | 0.38% | 1.502 ms | 0.38% | 100000 | +| 500000 | 10 | 2720x | 4.388 ms | 0.57% | 4.385 ms | 0.57% | 500000 | +| 10000 | 50 | 2320x | 2.553 ms | 0.80% | 2.549 ms | 0.80% | 10000 | +| 100000 | 50 | 1416x | 7.034 ms | 0.50% | 7.030 ms | 0.50% | 100000 | +| 500000 | 50 | 19x | 26.650 ms | 0.15% | 26.646 ms | 0.15% | 500000 | +| 10000 | 200 | 50x | 10.037 ms | 0.17% | 10.033 ms | 0.17% | 10000 | +| 100000 | 200 | 18x | 28.591 ms | 0.22% | 28.587 ms | 0.22% | 100000 | +| 500000 | 200 | 11x | 115.027 ms | 0.11% | 115.022 ms | 0.11% | 500000 | + +## Protobuf Nested Message + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | inner_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|--------------|---------|------------|-------|------------|-------|--------| +| 10000 | 5 | 716x | 702.294 us | 0.44% | 698.673 us | 0.44% | 10000 | +| 100000 | 5 | 896x | 1.555 ms | 0.82% | 1.551 ms | 0.82% | 100000 | +| 500000 | 5 | 2752x | 4.790 ms | 6.02% | 4.787 ms | 6.02% | 500000 | +| 10000 | 20 | 205x | 2.449 ms | 0.27% | 2.445 ms | 0.27% | 10000 | +| 100000 | 20 | 105x | 4.798 ms | 0.16% | 4.794 ms | 0.16% | 100000 | +| 500000 | 20 | 456x 
| 13.999 ms | 0.50% | 13.995 ms | 0.50% | 500000 | +| 10000 | 100 | 1072x | 10.553 ms | 2.57% | 10.549 ms | 2.57% | 10000 | +| 100000 | 100 | 635x | 23.543 ms | 7.07% | 23.539 ms | 7.07% | 100000 | +| 500000 | 100 | 11x | 68.833 ms | 0.33% | 68.828 ms | 0.33% | 500000 | + +## Protobuf Repeated Fields + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | avg_items | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|-----------|---------|------------|-------|------------|-------|--------| +| 10000 | 1 | 1196x | 969.249 us | 0.50% | 965.625 us | 0.50% | 10000 | +| 100000 | 1 | 1168x | 2.366 ms | 0.59% | 2.362 ms | 0.59% | 100000 | +| 10000 | 5 | 1141x | 1.295 ms | 0.50% | 1.292 ms | 0.50% | 10000 | +| 100000 | 5 | 2240x | 4.003 ms | 0.62% | 3.999 ms | 0.62% | 100000 | +| 10000 | 20 | 226x | 2.224 ms | 0.27% | 2.221 ms | 0.27% | 10000 | +| 100000 | 20 | 65x | 7.700 ms | 0.13% | 7.696 ms | 0.13% | 100000 | diff --git a/perf-results/run_00_baseline_many_repeated.txt b/perf-results/run_00_baseline_many_repeated.txt new file mode 100644 index 0000000000..a6b5e09a79 --- /dev/null +++ b/perf-results/run_00_baseline_many_repeated.txt @@ -0,0 +1,47 @@ +# Devices + +## [0] `NVIDIA RTX 5880 Ada Generation` +* SM Version: 890 (PTX Version: 860) +* Number of SMs: 110 +* SM Default Clock Rate: 2460 MHz +* Global Memory: 45660 MiB Free / 48506 MiB Total +* Global Memory Bus Peak: 960 GB/sec (384-bit DDR @10001MHz) +* Max Shared Memory: 100 KiB/SM, 48 KiB/Block +* L2 Cache Size: 98304 KiB +* Maximum Active Blocks: 24/SM +* Maximum Active Threads: 1536/SM, 1024/Block +* Available Registers: 65536/SM, 65536/Block +* ECC Enabled: No + +# Log + +``` +Run: [1/6] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=10] +Pass: Cold: 1.891084ms GPU, 1.894789ms CPU, 1.03s total GPU, 1.09s total wall, 544x +Run: [2/6] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=10] +Pass: Cold: 5.595138ms GPU, 5.598711ms CPU, 6.90s total GPU, 
7.03s total wall, 1234x +Run: [3/6] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=30] +Warn: Current measurement timed out (15.01s) while over noise threshold (1.14% > 0.50%) +Pass: Cold: 5.936679ms GPU, 5.940334ms CPU, 14.75s total GPU, 15.01s total wall, 2485x +Run: [4/6] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=30] +Pass: Cold: 18.947469ms GPU, 18.951296ms CPU, 4.62s total GPU, 4.64s total wall, 244x +Run: [5/6] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=50] +Pass: Cold: 9.850844ms GPU, 9.854526ms CPU, 0.50s total GPU, 0.51s total wall, 51x +Run: [6/6] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=50] +Pass: Cold: 36.678007ms GPU, 36.682032ms CPU, 0.51s total GPU, 0.51s total wall, 14x +``` + +# Benchmark Results + +## Protobuf Many Repeated Fields + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | num_rep_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|----------------|---------|-----------|-------|-----------|-------|--------| +| 10000 | 10 | 544x | 1.895 ms | 0.68% | 1.891 ms | 0.68% | 10000 | +| 100000 | 10 | 1234x | 5.599 ms | 0.50% | 5.595 ms | 0.50% | 100000 | +| 10000 | 30 | 2485x | 5.940 ms | 1.14% | 5.937 ms | 1.14% | 10000 | +| 100000 | 30 | 244x | 18.951 ms | 0.50% | 18.947 ms | 0.50% | 100000 | +| 10000 | 50 | 51x | 9.855 ms | 0.17% | 9.851 ms | 0.17% | 10000 | +| 100000 | 50 | 14x | 36.682 ms | 0.13% | 36.678 ms | 0.13% | 100000 | diff --git a/perf-results/run_01_p0_batch_repeated_sync.txt b/perf-results/run_01_p0_batch_repeated_sync.txt new file mode 100644 index 0000000000..666d6829ff --- /dev/null +++ b/perf-results/run_01_p0_batch_repeated_sync.txt @@ -0,0 +1,118 @@ +# Devices + +## [0] `NVIDIA RTX 5880 Ada Generation` +* SM Version: 890 (PTX Version: 860) +* Number of SMs: 110 +* SM Default Clock Rate: 2460 MHz +* Global Memory: 45660 MiB Free / 48506 MiB Total +* Global Memory Bus Peak: 960 GB/sec 
(384-bit DDR @10001MHz) +* Max Shared Memory: 100 KiB/SM, 48 KiB/Block +* L2 Cache Size: 98304 KiB +* Maximum Active Blocks: 24/SM +* Maximum Active Threads: 1536/SM, 1024/Block +* Available Registers: 65536/SM, 65536/Block +* ECC Enabled: No + +# Log + +``` +Run: [1/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=10] +Pass: Cold: 0.717111ms GPU, 0.720729ms CPU, 1.54s total GPU, 1.79s total wall, 2144x +Run: [2/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=10] +Pass: Cold: 1.529744ms GPU, 1.533401ms CPU, 0.50s total GPU, 0.53s total wall, 327x +Run: [3/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=10] +Pass: Cold: 4.439003ms GPU, 4.442652ms CPU, 6.75s total GPU, 6.89s total wall, 1520x +Run: [4/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=50] +Pass: Cold: 2.655585ms GPU, 2.659248ms CPU, 6.20s total GPU, 6.44s total wall, 2336x +Run: [5/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=50] +Warn: Current measurement timed out (15.01s) while over noise threshold (5.03% > 0.50%) +Pass: Cold: 7.199181ms GPU, 7.202834ms CPU, 14.82s total GPU, 15.01s total wall, 2058x +Run: [6/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=50] +Pass: Cold: 26.918621ms GPU, 26.922675ms CPU, 0.51s total GPU, 0.51s total wall, 19x +Run: [7/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=200] +Pass: Cold: 10.300780ms GPU, 10.304434ms CPU, 0.50s total GPU, 0.51s total wall, 49x +Run: [8/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=200] +Pass: Cold: 29.027086ms GPU, 29.031588ms CPU, 0.52s total GPU, 0.52s total wall, 18x +Run: [9/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=200] +Pass: Cold: 116.469964ms GPU, 116.475358ms CPU, 1.28s total GPU, 1.28s total wall, 11x +Run: [10/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=5] +Pass: Cold: 0.723146ms GPU, 0.726771ms CPU, 0.50s total GPU, 0.58s total wall, 692x +Run: [11/24] Protobuf 
Nested Message [Device=0 num_rows=100000 inner_fields=5] +Pass: Cold: 1.589212ms GPU, 1.593000ms CPU, 0.50s total GPU, 0.53s total wall, 315x +Run: [12/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=5] +Pass: Cold: 4.805955ms GPU, 4.809632ms CPU, 6.38s total GPU, 6.51s total wall, 1328x +Run: [13/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=20] +Pass: Cold: 2.522655ms GPU, 2.526255ms CPU, 0.50s total GPU, 0.52s total wall, 199x +Run: [14/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=20] +Pass: Cold: 4.899035ms GPU, 4.902721ms CPU, 0.50s total GPU, 0.51s total wall, 103x +Run: [15/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=20] +Warn: Current measurement timed out (15.01s) while over noise threshold (1.47% > 0.50%) +Pass: Cold: 14.223304ms GPU, 14.227113ms CPU, 14.93s total GPU, 15.01s total wall, 1050x +Run: [16/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=100] +Pass: Cold: 11.633095ms GPU, 11.636902ms CPU, 5.96s total GPU, 6.01s total wall, 512x +Run: [17/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=100] +Warn: Current measurement timed out (15.02s) while over noise threshold (3.23% > 0.50%) +Pass: Cold: 24.105674ms GPU, 24.109725ms CPU, 14.97s total GPU, 15.02s total wall, 621x +Run: [18/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=100] +Pass: Cold: 69.203103ms GPU, 69.208290ms CPU, 0.76s total GPU, 0.76s total wall, 11x +Run: [19/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=1] +Pass: Cold: 1.898922ms GPU, 1.902667ms CPU, 0.50s total GPU, 0.53s total wall, 264x +Run: [20/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=1] +Pass: Cold: 3.314200ms GPU, 3.318216ms CPU, 2.81s total GPU, 2.90s total wall, 848x +Run: [21/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=5] +Pass: Cold: 2.205836ms GPU, 2.209719ms CPU, 0.50s total GPU, 0.52s total wall, 227x +Run: [22/24] 
Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=5] +Pass: Cold: 5.259358ms GPU, 5.263366ms CPU, 0.50s total GPU, 0.51s total wall, 96x +Run: [23/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=20] +Pass: Cold: 3.261910ms GPU, 3.265949ms CPU, 1.90s total GPU, 1.96s total wall, 583x +Run: [24/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=20] +Warn: Current measurement timed out (15.01s) while over noise threshold (0.85% > 0.50%) +Pass: Cold: 8.872888ms GPU, 8.877022ms CPU, 14.87s total GPU, 15.01s total wall, 1676x +``` + +# Benchmark Results + +## Protobuf Flat Scalars + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | num_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|------------|---------|------------|-------|------------|-------|--------| +| 10000 | 10 | 2144x | 720.729 us | 1.61% | 717.111 us | 1.62% | 10000 | +| 100000 | 10 | 327x | 1.533 ms | 0.36% | 1.530 ms | 0.36% | 100000 | +| 500000 | 10 | 1520x | 4.443 ms | 0.71% | 4.439 ms | 0.71% | 500000 | +| 10000 | 50 | 2336x | 2.659 ms | 2.30% | 2.656 ms | 2.30% | 10000 | +| 100000 | 50 | 2058x | 7.203 ms | 5.03% | 7.199 ms | 5.03% | 100000 | +| 500000 | 50 | 19x | 26.923 ms | 0.09% | 26.919 ms | 0.09% | 500000 | +| 10000 | 200 | 49x | 10.304 ms | 0.29% | 10.301 ms | 0.29% | 10000 | +| 100000 | 200 | 18x | 29.032 ms | 0.13% | 29.027 ms | 0.13% | 100000 | +| 500000 | 200 | 11x | 116.475 ms | 0.37% | 116.470 ms | 0.37% | 500000 | + +## Protobuf Nested Message + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | inner_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|--------------|---------|------------|-------|------------|-------|--------| +| 10000 | 5 | 692x | 726.771 us | 0.50% | 723.146 us | 0.50% | 10000 | +| 100000 | 5 | 315x | 1.593 ms | 0.34% | 1.589 ms | 0.34% | 100000 | +| 500000 | 5 | 1328x | 4.810 ms | 0.57% | 4.806 ms | 0.57% | 500000 | +| 10000 | 20 | 199x | 2.526 ms | 0.27% | 2.523 ms 
| 0.27% | 10000 | +| 100000 | 20 | 103x | 4.903 ms | 0.29% | 4.899 ms | 0.29% | 100000 | +| 500000 | 20 | 1050x | 14.227 ms | 1.46% | 14.223 ms | 1.47% | 500000 | +| 10000 | 100 | 512x | 11.637 ms | 5.12% | 11.633 ms | 5.12% | 10000 | +| 100000 | 100 | 621x | 24.110 ms | 3.23% | 24.106 ms | 3.23% | 100000 | +| 500000 | 100 | 11x | 69.208 ms | 0.42% | 69.203 ms | 0.42% | 500000 | + +## Protobuf Repeated Fields + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | avg_items | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|-----------|---------|----------|-------|----------|-------|--------| +| 10000 | 1 | 264x | 1.903 ms | 0.48% | 1.899 ms | 0.48% | 10000 | +| 100000 | 1 | 848x | 3.318 ms | 0.71% | 3.314 ms | 0.71% | 100000 | +| 10000 | 5 | 227x | 2.210 ms | 0.29% | 2.206 ms | 0.29% | 10000 | +| 100000 | 5 | 96x | 5.263 ms | 0.19% | 5.259 ms | 0.19% | 100000 | +| 10000 | 20 | 583x | 3.266 ms | 0.50% | 3.262 ms | 0.50% | 10000 | +| 100000 | 20 | 1676x | 8.877 ms | 0.85% | 8.873 ms | 0.85% | 100000 | diff --git a/perf-results/run_01_p0_many_repeated.txt b/perf-results/run_01_p0_many_repeated.txt new file mode 100644 index 0000000000..ce41520b18 --- /dev/null +++ b/perf-results/run_01_p0_many_repeated.txt @@ -0,0 +1,48 @@ +# Devices + +## [0] `NVIDIA RTX 5880 Ada Generation` +* SM Version: 890 (PTX Version: 860) +* Number of SMs: 110 +* SM Default Clock Rate: 2460 MHz +* Global Memory: 45660 MiB Free / 48506 MiB Total +* Global Memory Bus Peak: 960 GB/sec (384-bit DDR @10001MHz) +* Max Shared Memory: 100 KiB/SM, 48 KiB/Block +* L2 Cache Size: 98304 KiB +* Maximum Active Blocks: 24/SM +* Maximum Active Threads: 1536/SM, 1024/Block +* Available Registers: 65536/SM, 65536/Block +* ECC Enabled: No + +# Log + +``` +Run: [1/6] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=10] +Pass: Cold: 1.748073ms GPU, 1.751753ms CPU, 1.62s total GPU, 1.73s total wall, 928x +Run: [2/6] Protobuf Many Repeated Fields [Device=0 num_rows=100000 
num_rep_fields=10] +Warn: Current measurement timed out (15.00s) while over noise threshold (4.72% > 0.50%) +Pass: Cold: 5.443364ms GPU, 5.446938ms CPU, 14.72s total GPU, 15.00s total wall, 2705x +Run: [3/6] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=30] +Pass: Cold: 5.374920ms GPU, 5.378463ms CPU, 4.13s total GPU, 4.21s total wall, 768x +Run: [4/6] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=30] +Warn: Current measurement timed out (15.02s) while over noise threshold (0.65% > 0.50%) +Pass: Cold: 18.527305ms GPU, 18.531079ms CPU, 14.95s total GPU, 15.02s total wall, 807x +Run: [5/6] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=50] +Pass: Cold: 9.055705ms GPU, 9.059334ms CPU, 0.51s total GPU, 0.51s total wall, 56x +Run: [6/6] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=50] +Pass: Cold: 35.969858ms GPU, 35.973699ms CPU, 0.50s total GPU, 0.50s total wall, 14x +``` + +# Benchmark Results + +## Protobuf Many Repeated Fields + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | num_rep_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|----------------|---------|-----------|-------|-----------|-------|--------| +| 10000 | 10 | 928x | 1.752 ms | 0.80% | 1.748 ms | 0.80% | 10000 | +| 100000 | 10 | 2705x | 5.447 ms | 4.72% | 5.443 ms | 4.72% | 100000 | +| 10000 | 30 | 768x | 5.378 ms | 3.05% | 5.375 ms | 3.05% | 10000 | +| 100000 | 30 | 807x | 18.531 ms | 0.65% | 18.527 ms | 0.65% | 100000 | +| 10000 | 50 | 56x | 9.059 ms | 0.25% | 9.056 ms | 0.25% | 10000 | +| 100000 | 50 | 14x | 35.974 ms | 0.19% | 35.970 ms | 0.19% | 100000 | diff --git a/perf-results/run_01_p0_merge_reduce_scan.txt b/perf-results/run_01_p0_merge_reduce_scan.txt new file mode 100644 index 0000000000..9f4d4e8c93 --- /dev/null +++ b/perf-results/run_01_p0_merge_reduce_scan.txt @@ -0,0 +1,117 @@ +# Devices + +## [0] `NVIDIA RTX 5880 Ada Generation` +* SM Version: 890 (PTX 
Version: 860) +* Number of SMs: 110 +* SM Default Clock Rate: 2460 MHz +* Global Memory: 45660 MiB Free / 48506 MiB Total +* Global Memory Bus Peak: 960 GB/sec (384-bit DDR @10001MHz) +* Max Shared Memory: 100 KiB/SM, 48 KiB/Block +* L2 Cache Size: 98304 KiB +* Maximum Active Blocks: 24/SM +* Maximum Active Threads: 1536/SM, 1024/Block +* Available Registers: 65536/SM, 65536/Block +* ECC Enabled: No + +# Log + +``` +Run: [1/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=10] +Pass: Cold: 0.726904ms GPU, 0.730745ms CPU, 1.80s total GPU, 2.10s total wall, 2480x +Run: [2/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=10] +Pass: Cold: 1.536747ms GPU, 1.540522ms CPU, 1.18s total GPU, 1.26s total wall, 768x +Run: [3/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=10] +Pass: Cold: 4.443912ms GPU, 4.447723ms CPU, 14.43s total GPU, 14.76s total wall, 3248x +Run: [4/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=50] +Pass: Cold: 2.726921ms GPU, 2.730759ms CPU, 5.15s total GPU, 5.34s total wall, 1888x +Run: [5/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=50] +Warn: Current measurement timed out (15.01s) while over noise threshold (0.90% > 0.50%) +Pass: Cold: 7.246037ms GPU, 7.249840ms CPU, 14.81s total GPU, 15.01s total wall, 2044x +Run: [6/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=50] +Pass: Cold: 26.963141ms GPU, 26.967804ms CPU, 0.51s total GPU, 0.51s total wall, 19x +Run: [7/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=200] +Pass: Cold: 10.912177ms GPU, 10.916118ms CPU, 9.08s total GPU, 9.17s total wall, 832x +Run: [8/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=200] +Warn: Current measurement timed out (15.02s) while over noise threshold (0.86% > 0.50%) +Pass: Cold: 29.366643ms GPU, 29.370824ms CPU, 14.98s total GPU, 15.02s total wall, 510x +Run: [9/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=200] +Pass: Cold: 
116.532303ms GPU, 116.538900ms CPU, 1.28s total GPU, 1.28s total wall, 11x +Run: [10/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=5] +Pass: Cold: 0.749640ms GPU, 0.753178ms CPU, 1.09s total GPU, 1.26s total wall, 1456x +Run: [11/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=5] +Pass: Cold: 1.638890ms GPU, 1.642742ms CPU, 4.20s total GPU, 4.46s total wall, 2560x +Run: [12/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=5] +Pass: Cold: 4.810134ms GPU, 4.814028ms CPU, 12.24s total GPU, 12.49s total wall, 2544x +Run: [13/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=20] +Pass: Cold: 2.629933ms GPU, 2.633856ms CPU, 2.99s total GPU, 3.11s total wall, 1136x +Run: [14/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=20] +Pass: Cold: 5.105757ms GPU, 5.109912ms CPU, 4.74s total GPU, 4.83s total wall, 928x +Run: [15/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=20] +Pass: Cold: 14.065493ms GPU, 14.069656ms CPU, 6.64s total GPU, 6.67s total wall, 472x +Run: [16/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=100] +Pass: Cold: 11.388964ms GPU, 11.393028ms CPU, 7.29s total GPU, 7.35s total wall, 640x +Run: [17/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=100] +Warn: Current measurement timed out (15.00s) while over noise threshold (1.33% > 0.50%) +Pass: Cold: 24.113780ms GPU, 24.118162ms CPU, 14.95s total GPU, 15.00s total wall, 620x +Run: [18/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=100] +Pass: Cold: 69.461100ms GPU, 69.466652ms CPU, 0.76s total GPU, 0.76s total wall, 11x +Run: [19/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=1] +Pass: Cold: 0.944527ms GPU, 0.948112ms CPU, 0.50s total GPU, 0.56s total wall, 530x +Run: [20/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=1] +Pass: Cold: 2.308945ms GPU, 2.312709ms CPU, 0.50s total GPU, 0.52s total wall, 217x 
+Run: [21/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=5] +Pass: Cold: 1.271388ms GPU, 1.275154ms CPU, 2.85s total GPU, 3.09s total wall, 2240x +Run: [22/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=5] +Pass: Cold: 3.967131ms GPU, 3.970874ms CPU, 7.30s total GPU, 7.49s total wall, 1840x +Run: [23/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=20] +Pass: Cold: 2.220037ms GPU, 2.223704ms CPU, 0.50s total GPU, 0.52s total wall, 226x +Run: [24/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=20] +Pass: Cold: 7.729027ms GPU, 7.733813ms CPU, 0.50s total GPU, 0.51s total wall, 65x +``` + +# Benchmark Results + +## Protobuf Flat Scalars + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | num_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|------------|---------|------------|-------|------------|-------|--------| +| 10000 | 10 | 2480x | 730.745 us | 2.70% | 726.904 us | 2.72% | 10000 | +| 100000 | 10 | 768x | 1.541 ms | 1.42% | 1.537 ms | 1.43% | 100000 | +| 500000 | 10 | 3248x | 4.448 ms | 6.56% | 4.444 ms | 6.56% | 500000 | +| 10000 | 50 | 1888x | 2.731 ms | 1.65% | 2.727 ms | 1.65% | 10000 | +| 100000 | 50 | 2044x | 7.250 ms | 0.90% | 7.246 ms | 0.90% | 100000 | +| 500000 | 50 | 19x | 26.968 ms | 0.20% | 26.963 ms | 0.20% | 500000 | +| 10000 | 200 | 832x | 10.916 ms | 2.51% | 10.912 ms | 2.51% | 10000 | +| 100000 | 200 | 510x | 29.371 ms | 0.86% | 29.367 ms | 0.86% | 100000 | +| 500000 | 200 | 11x | 116.539 ms | 0.14% | 116.532 ms | 0.14% | 500000 | + +## Protobuf Nested Message + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | inner_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|--------------|---------|------------|-------|------------|-------|--------| +| 10000 | 5 | 1456x | 753.178 us | 1.51% | 749.640 us | 1.52% | 10000 | +| 100000 | 5 | 2560x | 1.643 ms | 0.90% | 1.639 ms | 0.89% | 100000 | +| 500000 | 5 | 2544x | 4.814 ms 
| 1.55% | 4.810 ms | 1.55% | 500000 | +| 10000 | 20 | 1136x | 2.634 ms | 1.91% | 2.630 ms | 1.91% | 10000 | +| 100000 | 20 | 928x | 5.110 ms | 1.47% | 5.106 ms | 1.47% | 100000 | +| 500000 | 20 | 472x | 14.070 ms | 0.50% | 14.065 ms | 0.50% | 500000 | +| 10000 | 100 | 640x | 11.393 ms | 4.83% | 11.389 ms | 4.83% | 10000 | +| 100000 | 100 | 620x | 24.118 ms | 1.33% | 24.114 ms | 1.33% | 100000 | +| 500000 | 100 | 11x | 69.467 ms | 0.14% | 69.461 ms | 0.14% | 500000 | + +## Protobuf Repeated Fields + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | avg_items | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|-----------|---------|------------|-------|------------|-------|--------| +| 10000 | 1 | 530x | 948.112 us | 0.35% | 944.527 us | 0.35% | 10000 | +| 100000 | 1 | 217x | 2.313 ms | 0.27% | 2.309 ms | 0.27% | 100000 | +| 10000 | 5 | 2240x | 1.275 ms | 1.13% | 1.271 ms | 1.13% | 10000 | +| 100000 | 5 | 1840x | 3.971 ms | 0.57% | 3.967 ms | 0.57% | 100000 | +| 10000 | 20 | 226x | 2.224 ms | 0.32% | 2.220 ms | 0.32% | 10000 | +| 100000 | 20 | 65x | 7.734 ms | 0.13% | 7.729 ms | 0.13% | 100000 | diff --git a/perf-results/run_02_p3_defer_error_check.txt b/perf-results/run_02_p3_defer_error_check.txt new file mode 100644 index 0000000000..dd47a5ba1b --- /dev/null +++ b/perf-results/run_02_p3_defer_error_check.txt @@ -0,0 +1,139 @@ +# Devices + +## [0] `NVIDIA RTX 5880 Ada Generation` +* SM Version: 890 (PTX Version: 860) +* Number of SMs: 110 +* SM Default Clock Rate: 2460 MHz +* Global Memory: 45660 MiB Free / 48506 MiB Total +* Global Memory Bus Peak: 960 GB/sec (384-bit DDR @10001MHz) +* Max Shared Memory: 100 KiB/SM, 48 KiB/Block +* L2 Cache Size: 98304 KiB +* Maximum Active Blocks: 24/SM +* Maximum Active Threads: 1536/SM, 1024/Block +* Available Registers: 65536/SM, 65536/Block +* ECC Enabled: No + +# Log + +``` +Run: [1/30] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=10] +Pass: Cold: 0.692966ms GPU, 0.696617ms CPU, 1.88s 
total GPU, 2.20s total wall, 2720x +Run: [2/30] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=10] +Pass: Cold: 1.491170ms GPU, 1.494817ms CPU, 4.75s total GPU, 5.07s total wall, 3184x +Run: [3/30] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=10] +Pass: Cold: 4.211614ms GPU, 4.215283ms CPU, 11.32s total GPU, 11.58s total wall, 2688x +Run: [4/30] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=50] +Pass: Cold: 2.584040ms GPU, 2.587704ms CPU, 8.14s total GPU, 8.47s total wall, 3152x +Run: [5/30] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=50] +Pass: Cold: 7.038277ms GPU, 7.041922ms CPU, 5.58s total GPU, 5.65s total wall, 793x +Run: [6/30] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=50] +Pass: Cold: 26.499879ms GPU, 26.504024ms CPU, 0.50s total GPU, 0.50s total wall, 19x +Run: [7/30] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=200] +Pass: Cold: 10.245305ms GPU, 10.249014ms CPU, 0.50s total GPU, 0.51s total wall, 49x +Run: [8/30] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=200] +Pass: Cold: 28.589312ms GPU, 28.593149ms CPU, 7.69s total GPU, 7.71s total wall, 269x +Run: [9/30] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=200] +Pass: Cold: 115.004578ms GPU, 115.010221ms CPU, 1.27s total GPU, 1.27s total wall, 11x +Run: [10/30] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=5] +Pass: Cold: 0.682600ms GPU, 0.686183ms CPU, 0.67s total GPU, 0.78s total wall, 976x +Run: [11/30] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=5] +Pass: Cold: 1.518160ms GPU, 1.522011ms CPU, 4.18s total GPU, 4.45s total wall, 2752x +Run: [12/30] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=5] +Pass: Cold: 4.583684ms GPU, 4.587425ms CPU, 9.11s total GPU, 9.30s total wall, 1987x +Run: [13/30] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=20] +Pass: Cold: 2.456863ms GPU, 2.460561ms CPU, 0.50s total GPU, 0.52s total wall, 204x 
+Run: [14/30] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=20] +Pass: Cold: 4.740401ms GPU, 4.744170ms CPU, 0.50s total GPU, 0.51s total wall, 106x +Run: [15/30] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=20] +Pass: Cold: 13.798471ms GPU, 13.802428ms CPU, 14.52s total GPU, 14.59s total wall, 1052x +Run: [16/30] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=100] +Pass: Cold: 10.622611ms GPU, 10.626416ms CPU, 0.51s total GPU, 0.51s total wall, 48x +Run: [17/30] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=100] +Pass: Cold: 23.489448ms GPU, 23.493497ms CPU, 13.15s total GPU, 13.20s total wall, 560x +Run: [18/30] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=100] +Pass: Cold: 69.207141ms GPU, 69.212410ms CPU, 5.74s total GPU, 5.75s total wall, 83x +Run: [19/30] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=1] +Pass: Cold: 0.839203ms GPU, 0.842775ms CPU, 0.50s total GPU, 0.57s total wall, 596x +Run: [20/30] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=1] +Pass: Cold: 2.152807ms GPU, 2.156525ms CPU, 0.50s total GPU, 0.52s total wall, 233x +Run: [21/30] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=5] +Pass: Cold: 1.148301ms GPU, 1.151898ms CPU, 0.50s total GPU, 0.54s total wall, 436x +Run: [22/30] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=5] +Pass: Cold: 3.803068ms GPU, 3.806759ms CPU, 7.05s total GPU, 7.23s total wall, 1853x +Run: [23/30] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=20] +Pass: Cold: 2.080413ms GPU, 2.084091ms CPU, 0.50s total GPU, 0.53s total wall, 241x +Run: [24/30] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=20] +Pass: Cold: 7.516108ms GPU, 7.520008ms CPU, 0.50s total GPU, 0.51s total wall, 67x +Run: [25/30] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=10] +Pass: Cold: 1.471735ms GPU, 1.475218ms CPU, 0.50s total GPU, 0.54s total wall, 340x +Run: 
[26/30] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=10] +Pass: Cold: 5.010848ms GPU, 5.014496ms CPU, 0.50s total GPU, 0.51s total wall, 100x +Run: [27/30] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=30] +Pass: Cold: 4.675919ms GPU, 4.679577ms CPU, 0.50s total GPU, 0.51s total wall, 107x +Run: [28/30] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=30] +Pass: Cold: 17.064845ms GPU, 17.068522ms CPU, 0.51s total GPU, 0.51s total wall, 30x +Run: [29/30] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=50] +Pass: Cold: 8.136901ms GPU, 8.140626ms CPU, 0.50s total GPU, 0.51s total wall, 62x +Run: [30/30] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=50] +Pass: Cold: 33.943523ms GPU, 33.947271ms CPU, 0.51s total GPU, 0.51s total wall, 15x +``` + +# Benchmark Results + +## Protobuf Flat Scalars + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | num_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|------------|---------|------------|-------|------------|-------|--------| +| 10000 | 10 | 2720x | 696.617 us | 1.30% | 692.966 us | 1.31% | 10000 | +| 100000 | 10 | 3184x | 1.495 ms | 5.80% | 1.491 ms | 5.81% | 100000 | +| 500000 | 10 | 2688x | 4.215 ms | 0.55% | 4.212 ms | 0.55% | 500000 | +| 10000 | 50 | 3152x | 2.588 ms | 1.31% | 2.584 ms | 1.31% | 10000 | +| 100000 | 50 | 793x | 7.042 ms | 0.50% | 7.038 ms | 0.50% | 100000 | +| 500000 | 50 | 19x | 26.504 ms | 0.07% | 26.500 ms | 0.07% | 500000 | +| 10000 | 200 | 49x | 10.249 ms | 0.34% | 10.245 ms | 0.34% | 10000 | +| 100000 | 200 | 269x | 28.593 ms | 0.50% | 28.589 ms | 0.50% | 100000 | +| 500000 | 200 | 11x | 115.010 ms | 0.32% | 115.005 ms | 0.32% | 500000 | + +## Protobuf Nested Message + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | inner_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | 
+|----------|--------------|---------|------------|-------|------------|-------|--------| +| 10000 | 5 | 976x | 686.183 us | 0.50% | 682.600 us | 0.51% | 10000 | +| 100000 | 5 | 2752x | 1.522 ms | 0.86% | 1.518 ms | 0.86% | 100000 | +| 500000 | 5 | 1987x | 4.587 ms | 0.50% | 4.584 ms | 0.50% | 500000 | +| 10000 | 20 | 204x | 2.461 ms | 0.24% | 2.457 ms | 0.24% | 10000 | +| 100000 | 20 | 106x | 4.744 ms | 0.15% | 4.740 ms | 0.15% | 100000 | +| 500000 | 20 | 1052x | 13.802 ms | 0.50% | 13.798 ms | 0.50% | 500000 | +| 10000 | 100 | 48x | 10.626 ms | 0.27% | 10.623 ms | 0.27% | 10000 | +| 100000 | 100 | 560x | 23.493 ms | 1.13% | 23.489 ms | 1.13% | 100000 | +| 500000 | 100 | 83x | 69.212 ms | 0.50% | 69.207 ms | 0.50% | 500000 | + +## Protobuf Repeated Fields + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | avg_items | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|-----------|---------|------------|-------|------------|-------|--------| +| 10000 | 1 | 596x | 842.775 us | 0.35% | 839.203 us | 0.35% | 10000 | +| 100000 | 1 | 233x | 2.157 ms | 0.27% | 2.153 ms | 0.27% | 100000 | +| 10000 | 5 | 436x | 1.152 ms | 0.47% | 1.148 ms | 0.47% | 10000 | +| 100000 | 5 | 1853x | 3.807 ms | 0.50% | 3.803 ms | 0.50% | 100000 | +| 10000 | 20 | 241x | 2.084 ms | 0.23% | 2.080 ms | 0.24% | 10000 | +| 100000 | 20 | 67x | 7.520 ms | 0.13% | 7.516 ms | 0.13% | 100000 | + +## Protobuf Many Repeated Fields + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | num_rep_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|----------------|---------|-----------|-------|-----------|-------|--------| +| 10000 | 10 | 340x | 1.475 ms | 0.27% | 1.472 ms | 0.27% | 10000 | +| 100000 | 10 | 100x | 5.014 ms | 0.49% | 5.011 ms | 0.49% | 100000 | +| 10000 | 30 | 107x | 4.680 ms | 0.14% | 4.676 ms | 0.14% | 10000 | +| 100000 | 30 | 30x | 17.069 ms | 0.13% | 17.065 ms | 0.13% | 100000 | +| 10000 | 50 | 62x | 8.141 ms | 0.17% | 8.137 ms | 0.17% | 
10000 | +| 100000 | 50 | 15x | 33.947 ms | 0.09% | 33.944 ms | 0.09% | 100000 | diff --git a/perf-results/run_03_p1_combined_occurrence_scan.txt b/perf-results/run_03_p1_combined_occurrence_scan.txt new file mode 100644 index 0000000000..85e244b8fe --- /dev/null +++ b/perf-results/run_03_p1_combined_occurrence_scan.txt @@ -0,0 +1,143 @@ +# Devices + +## [0] `NVIDIA RTX 5880 Ada Generation` +* SM Version: 890 (PTX Version: 860) +* Number of SMs: 110 +* SM Default Clock Rate: 2460 MHz +* Global Memory: 45660 MiB Free / 48506 MiB Total +* Global Memory Bus Peak: 960 GB/sec (384-bit DDR @10001MHz) +* Max Shared Memory: 100 KiB/SM, 48 KiB/Block +* L2 Cache Size: 98304 KiB +* Maximum Active Blocks: 24/SM +* Maximum Active Threads: 1536/SM, 1024/Block +* Available Registers: 65536/SM, 65536/Block +* ECC Enabled: No + +# Log + +``` +Run: [1/30] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=10] +Pass: Cold: 0.686546ms GPU, 0.690229ms CPU, 1.96s total GPU, 2.29s total wall, 2848x +Run: [2/30] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=10] +Pass: Cold: 1.501841ms GPU, 1.505566ms CPU, 0.50s total GPU, 0.53s total wall, 333x +Run: [3/30] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=10] +Pass: Cold: 4.243665ms GPU, 4.247329ms CPU, 11.81s total GPU, 12.09s total wall, 2784x +Run: [4/30] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=50] +Pass: Cold: 2.617960ms GPU, 2.621655ms CPU, 7.04s total GPU, 7.31s total wall, 2688x +Run: [5/30] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=50] +Pass: Cold: 7.109102ms GPU, 7.112786ms CPU, 9.78s total GPU, 9.91s total wall, 1376x +Run: [6/30] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=50] +Pass: Cold: 26.655070ms GPU, 26.659158ms CPU, 0.51s total GPU, 0.51s total wall, 19x +Run: [7/30] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=200] +Pass: Cold: 10.439957ms GPU, 10.443692ms CPU, 11.36s total GPU, 11.47s total wall, 1088x +Run: [8/30] Protobuf Flat 
Scalars [Device=0 num_rows=100000 num_fields=200] +Warn: Current measurement timed out (15.02s) while over noise threshold (1.48% > 0.50%) +Pass: Cold: 29.198385ms GPU, 29.202193ms CPU, 14.98s total GPU, 15.02s total wall, 513x +Run: [9/30] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=200] +Pass: Cold: 115.994634ms GPU, 115.999994ms CPU, 7.42s total GPU, 7.43s total wall, 64x +Run: [10/30] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=5] +Pass: Cold: 0.682638ms GPU, 0.686217ms CPU, 1.07s total GPU, 1.25s total wall, 1568x +Run: [11/30] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=5] +Pass: Cold: 1.538363ms GPU, 1.542351ms CPU, 4.48s total GPU, 4.77s total wall, 2912x +Run: [12/30] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=5] +Pass: Cold: 4.651629ms GPU, 4.655404ms CPU, 13.02s total GPU, 13.30s total wall, 2800x +Run: [13/30] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=20] +Pass: Cold: 2.458447ms GPU, 2.462209ms CPU, 0.50s total GPU, 0.52s total wall, 204x +Run: [14/30] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=20] +Pass: Cold: 4.817362ms GPU, 4.821159ms CPU, 0.50s total GPU, 0.51s total wall, 104x +Run: [15/30] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=20] +Warn: Current measurement timed out (15.00s) while over noise threshold (3.64% > 0.50%) +Pass: Cold: 13.952591ms GPU, 13.956618ms CPU, 14.93s total GPU, 15.00s total wall, 1070x +Run: [16/30] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=100] +Pass: Cold: 11.263608ms GPU, 11.267472ms CPU, 5.95s total GPU, 6.00s total wall, 528x +Run: [17/30] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=100] +Pass: Cold: 24.756087ms GPU, 24.760204ms CPU, 13.86s total GPU, 13.91s total wall, 560x +Run: [18/30] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=100] +Warn: Current measurement timed out (15.00s) while over noise threshold (0.91% > 0.50%) +Pass: 
Cold: 70.025400ms GPU, 70.030674ms CPU, 14.99s total GPU, 15.00s total wall, 214x +Run: [19/30] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=1] +Pass: Cold: 0.852793ms GPU, 0.856320ms CPU, 0.50s total GPU, 0.57s total wall, 587x +Run: [20/30] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=1] +Pass: Cold: 2.281557ms GPU, 2.285314ms CPU, 0.50s total GPU, 0.52s total wall, 220x +Run: [21/30] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=5] +Pass: Cold: 1.177153ms GPU, 1.180749ms CPU, 1.45s total GPU, 1.58s total wall, 1232x +Run: [22/30] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=5] +Pass: Cold: 4.102615ms GPU, 4.106303ms CPU, 9.32s total GPU, 9.55s total wall, 2272x +Run: [23/30] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=20] +Pass: Cold: 2.129240ms GPU, 2.132977ms CPU, 0.50s total GPU, 0.52s total wall, 235x +Run: [24/30] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=20] +Pass: Cold: 7.738479ms GPU, 7.742156ms CPU, 0.50s total GPU, 0.51s total wall, 65x +Run: [25/30] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=10] +Pass: Cold: 1.637901ms GPU, 1.641626ms CPU, 0.50s total GPU, 0.53s total wall, 306x +Run: [26/30] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=10] +Pass: Cold: 5.579029ms GPU, 5.582809ms CPU, 4.64s total GPU, 4.73s total wall, 832x +Run: [27/30] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=30] +Pass: Cold: 4.620287ms GPU, 4.623976ms CPU, 0.50s total GPU, 0.51s total wall, 109x +Run: [28/30] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=30] +Pass: Cold: 16.545099ms GPU, 16.548836ms CPU, 12.44s total GPU, 12.51s total wall, 752x +Run: [29/30] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=50] +Pass: Cold: 7.957812ms GPU, 7.961486ms CPU, 4.46s total GPU, 4.51s total wall, 560x +Run: [30/30] Protobuf Many Repeated Fields [Device=0 
num_rows=100000 num_rep_fields=50] +Warn: Current measurement timed out (15.01s) while over noise threshold (0.69% > 0.50%) +Pass: Cold: 27.663957ms GPU, 27.667840ms CPU, 14.97s total GPU, 15.01s total wall, 541x +``` + +# Benchmark Results + +## Protobuf Flat Scalars + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | num_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|------------|---------|------------|-------|------------|-------|--------| +| 10000 | 10 | 2848x | 690.229 us | 1.41% | 686.546 us | 1.41% | 10000 | +| 100000 | 10 | 333x | 1.506 ms | 0.35% | 1.502 ms | 0.35% | 100000 | +| 500000 | 10 | 2784x | 4.247 ms | 0.73% | 4.244 ms | 0.73% | 500000 | +| 10000 | 50 | 2688x | 2.622 ms | 6.22% | 2.618 ms | 6.23% | 10000 | +| 100000 | 50 | 1376x | 7.113 ms | 1.14% | 7.109 ms | 1.14% | 100000 | +| 500000 | 50 | 19x | 26.659 ms | 0.06% | 26.655 ms | 0.06% | 500000 | +| 10000 | 200 | 1088x | 10.444 ms | 2.81% | 10.440 ms | 2.81% | 10000 | +| 100000 | 200 | 513x | 29.202 ms | 1.48% | 29.198 ms | 1.48% | 100000 | +| 500000 | 200 | 64x | 116.000 ms | 0.50% | 115.995 ms | 0.50% | 500000 | + +## Protobuf Nested Message + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | inner_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|--------------|---------|------------|-------|------------|-------|--------| +| 10000 | 5 | 1568x | 686.217 us | 0.53% | 682.638 us | 0.54% | 10000 | +| 100000 | 5 | 2912x | 1.542 ms | 1.12% | 1.538 ms | 1.13% | 100000 | +| 500000 | 5 | 2800x | 4.655 ms | 0.55% | 4.652 ms | 0.55% | 500000 | +| 10000 | 20 | 204x | 2.462 ms | 0.23% | 2.458 ms | 0.23% | 10000 | +| 100000 | 20 | 104x | 4.821 ms | 0.28% | 4.817 ms | 0.28% | 100000 | +| 500000 | 20 | 1070x | 13.957 ms | 3.64% | 13.953 ms | 3.64% | 500000 | +| 10000 | 100 | 528x | 11.267 ms | 5.07% | 11.264 ms | 5.07% | 10000 | +| 100000 | 100 | 560x | 24.760 ms | 2.25% | 24.756 ms | 2.25% | 100000 | +| 500000 | 100 | 214x | 70.031 ms | 
0.91% | 70.025 ms | 0.91% | 500000 | + +## Protobuf Repeated Fields + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | avg_items | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|-----------|---------|------------|-------|------------|-------|--------| +| 10000 | 1 | 587x | 856.320 us | 0.36% | 852.793 us | 0.37% | 10000 | +| 100000 | 1 | 220x | 2.285 ms | 0.22% | 2.282 ms | 0.22% | 100000 | +| 10000 | 5 | 1232x | 1.181 ms | 0.63% | 1.177 ms | 0.63% | 10000 | +| 100000 | 5 | 2272x | 4.106 ms | 0.65% | 4.103 ms | 0.65% | 100000 | +| 10000 | 20 | 235x | 2.133 ms | 0.22% | 2.129 ms | 0.22% | 10000 | +| 100000 | 20 | 65x | 7.742 ms | 0.30% | 7.738 ms | 0.30% | 100000 | + +## Protobuf Many Repeated Fields + +### [0] NVIDIA RTX 5880 Ada Generation + +| num_rows | num_rep_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | +|----------|----------------|---------|-----------|-------|-----------|-------|--------| +| 10000 | 10 | 306x | 1.642 ms | 0.31% | 1.638 ms | 0.31% | 10000 | +| 100000 | 10 | 832x | 5.583 ms | 0.75% | 5.579 ms | 0.75% | 100000 | +| 10000 | 30 | 109x | 4.624 ms | 0.21% | 4.620 ms | 0.21% | 10000 | +| 100000 | 30 | 752x | 16.549 ms | 1.02% | 16.545 ms | 1.02% | 100000 | +| 10000 | 50 | 560x | 7.961 ms | 2.99% | 7.958 ms | 2.99% | 10000 | +| 100000 | 50 | 541x | 27.668 ms | 0.69% | 27.664 ms | 0.69% | 100000 | diff --git a/src/main/cpp/benchmarks/protobuf_decode.cu b/src/main/cpp/benchmarks/protobuf_decode.cu index 05cb5a8833..4bf4bd622b 100644 --- a/src/main/cpp/benchmarks/protobuf_decode.cu +++ b/src/main/cpp/benchmarks/protobuf_decode.cu @@ -139,17 +139,12 @@ std::unique_ptr make_binary_column(std::vector( - cudf::data_type{cudf::type_id::UINT8}, - total_bytes, - std::move(d_data), - rmm::device_buffer{}, - 0); - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, - static_cast(h_offsets.size()), - std::move(d_offsets), - rmm::device_buffer{}, - 0); + 
cudf::data_type{cudf::type_id::UINT8}, total_bytes, std::move(d_data), rmm::device_buffer{}, 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + static_cast(h_offsets.size()), + std::move(d_offsets), + rmm::device_buffer{}, + 0); return cudf::make_lists_column(static_cast(messages.size()), std::move(offsets_col), @@ -187,14 +182,14 @@ struct FlatScalarCase { cudf::type_id::FLOAT32, cudf::type_id::FLOAT64, cudf::type_id::BOOL8}; - int wt_for_type[] = {0 /*WT_VARINT*/, 0, 5 /*WT_32BIT*/, 1 /*WT_64BIT*/, 0}; + int wt_for_type[] = {0 /*WT_VARINT*/, 0, 5 /*WT_32BIT*/, 1 /*WT_64BIT*/, 0}; int fn = 1; for (int i = 0; i < num_int_fields; i++, fn++) { - int ti = i % 5; - auto ty = int_types[ti]; - int wt = wt_for_type[ti]; - int enc = spark_rapids_jni::ENC_DEFAULT; + int ti = i % 5; + auto ty = int_types[ti]; + int wt = wt_for_type[ti]; + int enc = spark_rapids_jni::ENC_DEFAULT; if (ty == cudf::type_id::FLOAT32) enc = spark_rapids_jni::ENC_FIXED; if (ty == cudf::type_id::FLOAT64) enc = spark_rapids_jni::ENC_FIXED; ctx.schema.push_back({fn, -1, 0, wt, ty, enc, false, false, false}); @@ -284,8 +279,7 @@ struct NestedMessageCase { for (int i = 0; i < num_inner_fields; i++) { int ti = i % 3; - ctx.schema.push_back( - {i + 1, 2, 1, inner_wt[ti], inner_types[ti], 0, false, false, false}); + ctx.schema.push_back({i + 1, 2, 1, inner_wt[ti], inner_types[ti], 0, false, false, false}); } size_t n = ctx.schema.size(); @@ -310,7 +304,8 @@ struct NestedMessageCase { auto random_string = [&](int len) { std::string s(len, ' '); - for (int c = 0; c < len; c++) s[c] = alphabet[rng() % alphabet.size()]; + for (int c = 0; c < len; c++) + s[c] = alphabet[rng() % alphabet.size()]; return s; }; @@ -393,7 +388,8 @@ struct RepeatedFieldCase { auto random_string = [&](int len) { std::string s(len, ' '); - for (int c = 0; c < len; c++) s[c] = alphabet[rng() % alphabet.size()]; + for (int c = 0; c < len; c++) + s[c] = alphabet[rng() % alphabet.size()]; return s; }; @@ 
-414,7 +410,8 @@ struct RepeatedFieldCase { { int n = vary(avg_tags_per_row); std::vector tags(n); - for (auto& t : tags) t = int_dist(rng); + for (auto& t : tags) + t = int_dist(rng); if (n > 0) encode_packed_repeated_int32(buf, 2, tags); } @@ -442,6 +439,98 @@ struct RepeatedFieldCase { } }; +// Case 4: Many repeated fields — stress-tests per-repeated-field sync overhead. +// message WideRepeatedMessage { +// int32 id = 1; +// repeated int32 r_int_1 = 2; +// repeated int32 r_int_2 = 3; +// ... +// repeated string r_str_1 = N; +// repeated string r_str_2 = N+1; +// ... +// } +struct ManyRepeatedFieldsCase { + int num_repeated_int; + int num_repeated_str; + + spark_rapids_jni::ProtobufDecodeContext build_context() const + { + spark_rapids_jni::ProtobufDecodeContext ctx; + ctx.fail_on_errors = true; + + int fn = 1; + // idx 0: id (scalar) + ctx.schema.push_back({fn++, -1, 0, 0, cudf::type_id::INT32, 0, false, false, false}); + + for (int i = 0; i < num_repeated_int; i++) { + ctx.schema.push_back({fn++, -1, 0, 0, cudf::type_id::INT32, 0, true, false, false}); + } + for (int i = 0; i < num_repeated_str; i++) { + ctx.schema.push_back({fn++, -1, 0, 2, cudf::type_id::STRING, 0, true, false, false}); + } + + size_t n = ctx.schema.size(); + for (auto const& f : ctx.schema) { + ctx.schema_output_types.emplace_back(f.output_type); + } + ctx.default_ints.resize(n, 0); + ctx.default_floats.resize(n, 0.0); + ctx.default_bools.resize(n, false); + ctx.default_strings.resize(n); + ctx.enum_valid_values.resize(n); + ctx.enum_names.resize(n); + return ctx; + } + + std::vector> generate_messages(int num_rows, + int avg_elems_per_field, + std::mt19937& rng) const + { + std::vector> messages(num_rows); + std::uniform_int_distribution int_dist(0, 100000); + std::uniform_int_distribution str_len_dist(3, 15); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + + auto random_string = [&](int len) { + std::string s(len, ' '); + for (int c = 0; c < len; c++) + s[c] = alphabet[rng() % 
alphabet.size()]; + return s; + }; + auto vary = [&](int avg) -> int { + int lo = std::max(0, avg / 2); + int hi = avg + avg / 2; + return std::uniform_int_distribution(lo, std::max(lo, hi))(rng); + }; + + for (int r = 0; r < num_rows; r++) { + auto& buf = messages[r]; + int fn = 1; + + encode_varint_field(buf, fn++, int_dist(rng)); + + for (int i = 0; i < num_repeated_int; i++) { + int cur_fn = fn++; + int n = vary(avg_elems_per_field); + if (n > 0) { + std::vector vals(n); + for (auto& v : vals) + v = int_dist(rng); + encode_packed_repeated_int32(buf, cur_fn, vals); + } + } + for (int i = 0; i < num_repeated_str; i++) { + int cur_fn = fn++; + int n = vary(avg_elems_per_field); + for (int j = 0; j < n; j++) { + encode_string_field(buf, cur_fn, random_string(str_len_dist(rng))); + } + } + } + return messages; + } +}; + } // anonymous namespace // =========================================================================== @@ -449,20 +538,21 @@ struct RepeatedFieldCase { // =========================================================================== static void BM_protobuf_flat_scalars(nvbench::state& state) { - auto const num_rows = static_cast(state.get_int64("num_rows")); - auto const num_fields = static_cast(state.get_int64("num_fields")); - int const num_str = std::max(1, num_fields / 10); - int const num_int = num_fields - num_str; + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const num_fields = static_cast(state.get_int64("num_fields")); + int const num_str = std::max(1, num_fields / 10); + int const num_int = num_fields - num_str; FlatScalarCase flat_case{num_int, num_str}; auto ctx = flat_case.build_context(); std::mt19937 rng(42); - auto messages = flat_case.generate_messages(num_rows, rng); - auto binary_col = make_binary_column(messages); + auto messages = flat_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); size_t total_bytes = 0; - for (auto const& m : messages) total_bytes += m.size(); 
+ for (auto const& m : messages) + total_bytes += m.size(); auto stream = cudf::get_default_stream(); state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); @@ -484,18 +574,19 @@ NVBENCH_BENCH(BM_protobuf_flat_scalars) // =========================================================================== static void BM_protobuf_nested(nvbench::state& state) { - auto const num_rows = static_cast(state.get_int64("num_rows")); - auto const inner_fields = static_cast(state.get_int64("inner_fields")); + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const inner_fields = static_cast(state.get_int64("inner_fields")); NestedMessageCase nested_case{inner_fields}; auto ctx = nested_case.build_context(); std::mt19937 rng(42); - auto messages = nested_case.generate_messages(num_rows, rng); - auto binary_col = make_binary_column(messages); + auto messages = nested_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); size_t total_bytes = 0; - for (auto const& m : messages) total_bytes += m.size(); + for (auto const& m : messages) + total_bytes += m.size(); auto stream = cudf::get_default_stream(); state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); @@ -524,11 +615,12 @@ static void BM_protobuf_repeated(nvbench::state& state) auto ctx = rep_case.build_context(); std::mt19937 rng(42); - auto messages = rep_case.generate_messages(num_rows, rng); - auto binary_col = make_binary_column(messages); + auto messages = rep_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); size_t total_bytes = 0; - for (auto const& m : messages) total_bytes += m.size(); + for (auto const& m : messages) + total_bytes += m.size(); auto stream = cudf::get_default_stream(); state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); @@ -544,3 +636,40 @@ NVBENCH_BENCH(BM_protobuf_repeated) .set_name("Protobuf Repeated Fields") .add_int64_axis("num_rows", {10'000, 100'000}) 
.add_int64_axis("avg_items", {1, 5, 20}); + +// =========================================================================== +// Benchmark 4: Many repeated fields — measures per-field sync overhead at scale +// =========================================================================== +static void BM_protobuf_many_repeated(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const num_rep_fields = static_cast(state.get_int64("num_rep_fields")); + + int const num_rep_str = std::max(1, num_rep_fields / 5); + int const num_rep_int = num_rep_fields - num_rep_str; + + ManyRepeatedFieldsCase many_case{num_rep_int, num_rep_str}; + auto ctx = many_case.build_context(); + + std::mt19937 rng(42); + auto messages = many_case.generate_messages(num_rows, /*avg_elems=*/3, rng); + auto binary_col = make_binary_column(messages); + + size_t total_bytes = 0; + for (auto const& m : messages) + total_bytes += m.size(); + + auto stream = cudf::get_default_stream(); + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + auto result = spark_rapids_jni::decode_protobuf_to_struct(binary_col->view(), ctx, stream); + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_many_repeated) + .set_name("Protobuf Many Repeated Fields") + .add_int64_axis("num_rows", {10'000, 100'000}) + .add_int64_axis("num_rep_fields", {10, 30, 50}); diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 953083a676..069ecf14fa 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -141,15 +141,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& default: return "Protobuf decode error: unknown error"; } }; - auto check_error_and_throw = [&]() { - if (!fail_on_errors) return; - int h_error = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync( - &h_error, 
d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); - stream.synchronize(); - if (h_error != 0) { throw cudf::logic_error(error_message(h_error)); } - }; - // Enum validation support (PERMISSIVE mode) bool has_enum_fields = std::any_of( enum_valid_values.begin(), enum_valid_values.end(), [](auto const& v) { return !v.empty(); }); @@ -199,7 +190,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& num_nested, d_nested_indices.data(), d_error.data()); - check_error_and_throw(); } // Store decoded columns by schema index for ordered assembly at the end. @@ -242,7 +232,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& static_cast(h_field_lookup.size()), d_locations.data(), d_error.data()); - check_error_and_throw(); // Check required fields (after scan pass) { @@ -267,7 +256,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& stream.value())); check_required_fields_kernel<<>>( d_locations.data(), d_is_required.data(), num_scalar, num_rows, d_error.data()); - check_error_and_throw(); } } @@ -408,53 +396,95 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } } - // Process repeated fields + // Process repeated fields (three-phase: offsets → combined scan → build columns) if (num_repeated > 0) { + // Phase A: Compute per-row offsets for each repeated field. 
+ struct repeated_field_work { + int schema_idx; + int32_t total_count{0}; + rmm::device_uvector counts; + rmm::device_uvector offsets; + std::unique_ptr> occurrences; + + repeated_field_work(int si, + cudf::size_type n, + rmm::cuda_stream_view s, + rmm::device_async_resource_ref m) + : schema_idx(si), counts(n, s, m), offsets(n + 1, s, m) + { + } + }; + + std::vector> rep_work; + rep_work.reserve(num_repeated); + for (int ri = 0; ri < num_repeated; ri++) { - int schema_idx = repeated_field_indices[ri]; - auto element_type = schema_output_types[schema_idx]; + int schema_idx = repeated_field_indices[ri]; + auto& w = *rep_work.emplace_back( + std::make_unique(schema_idx, num_rows, stream, mr)); - // Get per-row counts for this repeated field entirely on GPU (performance fix!) - rmm::device_uvector d_field_counts(num_rows, stream, mr); thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), - d_field_counts.data(), + w.counts.data(), extract_strided_count{d_repeated_info.data(), ri, num_repeated}); - int64_t total_count = thrust::reduce( - rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), int64_t{0}); - CUDF_EXPECTS(total_count <= std::numeric_limits::max(), - "Total repeated element count exceeds INT32_MAX"); + CUDF_CUDA_TRY(cudaMemsetAsync(w.offsets.data(), 0, sizeof(int32_t), stream.value())); + thrust::inclusive_scan( + rmm::exec_policy(stream), w.counts.begin(), w.counts.end(), w.offsets.data() + 1); - if (total_count > 0) { - // Build offsets for occurrence scanning on GPU (performance fix!) 
- rmm::device_uvector d_occ_offsets(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy(stream), - d_field_counts.begin(), - d_field_counts.end(), - d_occ_offsets.data(), - 0); - // Set last element - int32_t total_count_i32 = static_cast(total_count); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_occ_offsets.data() + num_rows, - &total_count_i32, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(&w.total_count, + w.offsets.data() + num_rows, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + } + stream.synchronize(); - // Scan for all occurrences - rmm::device_uvector d_occurrences(total_count, stream, mr); - scan_repeated_field_occurrences_kernel<<>>( - *d_in, - d_schema.data(), - schema_idx, - 0, - d_occ_offsets.data(), - d_occurrences.data(), - d_error.data()); - - check_error_and_throw(); + // Phase B: Allocate occurrence buffers and launch ONE combined scan kernel. + std::vector h_scan_descs; + h_scan_descs.reserve(num_repeated); + + for (auto& wp : rep_work) { + if (wp->total_count > 0) { + wp->occurrences = + std::make_unique>( + wp->total_count, stream, mr); + h_scan_descs.push_back({schema[wp->schema_idx].field_number, + schema[wp->schema_idx].wire_type, + wp->offsets.data(), + wp->occurrences->data()}); + } + } + + if (!h_scan_descs.empty()) { + rmm::device_uvector d_scan_descs( + h_scan_descs.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_scan_descs.data(), + h_scan_descs.data(), + h_scan_descs.size() * sizeof(h_scan_descs[0]), + cudaMemcpyHostToDevice, + stream.value())); + + scan_all_repeated_occurrences_kernel<<>>( + *d_in, + d_schema.data(), + 0, + d_scan_descs.data(), + static_cast(h_scan_descs.size()), + d_error.data()); + } + + // Phase C: Build columns per field. 
+ for (int ri = 0; ri < num_repeated; ri++) { + auto& w = *rep_work[ri]; + int schema_idx = w.schema_idx; + auto element_type = schema_output_types[schema_idx]; + int32_t total_count = w.total_count; + auto& d_field_counts = w.counts; + + if (total_count > 0) { + auto& d_occurrences = *w.occurrences; // Build the appropriate column type based on element type auto child_type_id = static_cast(h_device_schema[schema_idx].output_type_id); diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index bc9b1dc86f..e260e337b2 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -115,6 +115,17 @@ struct repeated_occurrence { int32_t length; // Length of the field data }; +/** + * Per-field descriptor passed to the combined occurrence scan kernel. + * Contains device pointers so the kernel can write to each field's output. + */ +struct repeated_field_scan_desc { + int field_number; + int wire_type; + int32_t const* row_offsets; // Pre-computed prefix-sum offsets [num_rows + 1] + repeated_occurrence* occurrences; // Output buffer [total_count] +}; + /** * Device-side descriptor for nested schema fields. 
*/ @@ -858,6 +869,13 @@ __global__ void scan_repeated_field_occurrences_kernel(cudf::column_device_view repeated_occurrence* occurrences, int* error_flag); +__global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view const d_in, + device_nested_field_descriptor const* schema, + int depth_level, + repeated_field_scan_desc const* scan_descs, + int num_scan_fields, + int* error_flag); + __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, cudf::size_type const* parent_row_offsets, cudf::size_type parent_base_offset, diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index d5095972f3..e526f55912 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -441,6 +441,79 @@ __global__ void scan_repeated_field_occurrences_kernel( } } +/** + * Combined occurrence scan: scans each message ONCE and writes occurrences for ALL + * repeated fields simultaneously. Replaces N separate scan_repeated_field_occurrences_kernel + * launches with a single kernel, eliminating N-1 redundant full-message scans. 
+ */ +__global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view const d_in, + device_nested_field_descriptor const* schema, + int depth_level, + repeated_field_scan_desc const* scan_descs, + int num_scan_fields, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + cudf::detail::lists_column_device_view in{d_in}; + if (row >= in.size()) return; + + if (in.nullable() && in.is_null(row)) { return; } + + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); + auto const* bytes = reinterpret_cast(child.data()); + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; + if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } + + uint8_t const* cur = bytes + start; + uint8_t const* msg_end = bytes + end; + + // Per-field write indices, initialized from the pre-computed offsets. + // Use a fixed-size stack array to avoid dynamic allocation. + // MAX_STACK_FIELDS should be generous enough for practical schemas. + constexpr int MAX_STACK_FIELDS = 128; + int write_idx[MAX_STACK_FIELDS]; + int actual_fields = num_scan_fields < MAX_STACK_FIELDS ? 
num_scan_fields : MAX_STACK_FIELDS; + for (int f = 0; f < actual_fields; f++) { + write_idx[f] = scan_descs[f].row_offsets[row]; + } + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + int fn = tag.field_number; + int wt = tag.wire_type; + + for (int f = 0; f < actual_fields; f++) { + if (scan_descs[f].field_number == fn) { + int target_wt = scan_descs[f].wire_type; + bool is_packed = (wt == WT_LEN && target_wt != WT_LEN); + if (is_packed || wt == target_wt) { + if (!scan_repeated_element(cur, + msg_end, + bytes + start, + wt, + target_wt, + static_cast(row), + scan_descs[f].occurrences, + write_idx[f], + error_flag)) { + return; + } + } + } + } + + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } +} + // ============================================================================ // Nested message scanning kernels // ============================================================================ From 2b113c7155eae3e1553db292ec6cca965d09b028 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 4 Mar 2026 16:01:36 +0800 Subject: [PATCH 043/107] O1 field number lookup in nested Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 79 +++++++++++++--- src/main/cpp/src/protobuf_common.cuh | 31 ++++++- src/main/cpp/src/protobuf_kernels.cu | 129 ++++++++++++++++++--------- 3 files changed, 180 insertions(+), 59 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 069ecf14fa..f87b293909 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -177,19 +177,49 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& stream.value())); } - // Count repeated fields at depth 0 + // Count repeated fields at depth 0 (with O(1) field_number lookup tables) + rmm::device_uvector d_fn_to_rep(0, stream, mr); + rmm::device_uvector d_fn_to_nested(0, stream, mr); + 
if (num_repeated > 0 || num_nested > 0) { - count_repeated_fields_kernel<<>>(*d_in, - d_schema.data(), - num_fields, - 0, // depth_level - d_repeated_info.data(), - num_repeated, - d_repeated_indices.data(), - d_nested_locations.data(), - num_nested, - d_nested_indices.data(), - d_error.data()); + auto h_fn_to_rep = protobuf_detail::build_index_lookup_table( + schema.data(), repeated_field_indices.data(), num_repeated); + auto h_fn_to_nested = protobuf_detail::build_index_lookup_table( + schema.data(), nested_field_indices.data(), num_nested); + + if (!h_fn_to_rep.empty()) { + d_fn_to_rep = rmm::device_uvector(h_fn_to_rep.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_fn_to_rep.data(), + h_fn_to_rep.data(), + h_fn_to_rep.size() * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + } + if (!h_fn_to_nested.empty()) { + d_fn_to_nested = rmm::device_uvector(h_fn_to_nested.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_fn_to_nested.data(), + h_fn_to_nested.data(), + h_fn_to_nested.size() * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + } + + count_repeated_fields_kernel<<>>( + *d_in, + d_schema.data(), + num_fields, + 0, + d_repeated_info.data(), + num_repeated, + d_repeated_indices.data(), + d_nested_locations.data(), + num_nested, + d_nested_indices.data(), + d_error.data(), + d_fn_to_rep.data(), + static_cast(d_fn_to_rep.size()), + d_fn_to_nested.data(), + static_cast(d_fn_to_nested.size())); } // Store decoded columns by schema index for ordered assembly at the end. 
@@ -466,13 +496,36 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& cudaMemcpyHostToDevice, stream.value())); + // Build field_number -> scan_desc_index lookup for the combined kernel + int max_scan_fn = 0; + for (auto const& sd : h_scan_descs) { + max_scan_fn = std::max(max_scan_fn, sd.field_number); + } + rmm::device_uvector d_fn_to_scan(0, stream, mr); + int fn_to_scan_size = 0; + if (max_scan_fn <= protobuf_detail::FIELD_LOOKUP_TABLE_MAX) { + std::vector h_fn_to_scan(max_scan_fn + 1, -1); + for (int i = 0; i < static_cast(h_scan_descs.size()); i++) { + h_fn_to_scan[h_scan_descs[i].field_number] = i; + } + d_fn_to_scan = rmm::device_uvector(h_fn_to_scan.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_fn_to_scan.data(), + h_fn_to_scan.data(), + h_fn_to_scan.size() * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + fn_to_scan_size = static_cast(h_fn_to_scan.size()); + } + scan_all_repeated_occurrences_kernel<<>>( *d_in, d_schema.data(), 0, d_scan_descs.data(), static_cast(h_scan_descs.size()), - d_error.data()); + d_error.data(), + d_fn_to_scan.data(), + fn_to_scan_size); } // Phase C: Build columns per field. diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index e260e337b2..88ec29f0c9 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -355,6 +355,27 @@ __device__ inline uint64_t load_le(uint8_t const* p) // Field number lookup table helpers // ============================================================================ +/** + * Build a host-side direct-mapped lookup table: field_number -> local_index, + * given an array of schema indices and the schema itself. + * Returns an empty vector if the max field number exceeds the threshold. 
+ */ +inline std::vector build_index_lookup_table(nested_field_descriptor const* schema, + int const* field_indices, + int num_indices) +{ + int max_fn = 0; + for (int i = 0; i < num_indices; i++) { + max_fn = std::max(max_fn, schema[field_indices[i]].field_number); + } + if (max_fn > FIELD_LOOKUP_TABLE_MAX) return {}; + std::vector table(max_fn + 1, -1); + for (int i = 0; i < num_indices; i++) { + table[schema[field_indices[i]].field_number] = i; + } + return table; +} + /** * Build a host-side direct-mapped lookup table: field_number -> field_index. * Returns an empty vector if the max field number exceeds the threshold. @@ -859,7 +880,11 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in field_location* nested_locations, int num_nested_fields, int const* nested_field_indices, - int* error_flag); + int* error_flag, + int const* fn_to_rep_idx = nullptr, + int fn_to_rep_size = 0, + int const* fn_to_nested_idx = nullptr, + int fn_to_nested_size = 0); __global__ void scan_repeated_field_occurrences_kernel(cudf::column_device_view const d_in, device_nested_field_descriptor const* schema, @@ -874,7 +899,9 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co int depth_level, repeated_field_scan_desc const* scan_descs, int num_scan_fields, - int* error_flag); + int* error_flag, + int const* fn_to_desc_idx = nullptr, + int fn_to_desc_size = 0); __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, cudf::size_type const* parent_row_offsets, diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index e526f55912..0d70893637 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -280,21 +280,24 @@ __device__ bool scan_repeated_element(uint8_t const* cur, * Count occurrences of repeated fields in each row. * Also records locations of nested message fields for hierarchical processing. 
* - * @note Time complexity: O(message_length * (num_repeated_fields + num_nested_fields)) per row. + * Optional lookup tables (fn_to_rep_idx, fn_to_nested_idx) provide O(1) field_number + * to local index mapping. When nullptr, falls back to linear search. */ -__global__ void count_repeated_fields_kernel( - cudf::column_device_view const d_in, - device_nested_field_descriptor const* schema, - int num_fields, - int depth_level, // Which depth level we're processing - repeated_field_info* repeated_info, // [num_rows * num_repeated_fields_at_this_depth] - int num_repeated_fields, // Number of repeated fields at this depth - int const* repeated_field_indices, // Indices into schema for repeated fields at this depth - field_location* - nested_locations, // Locations of nested messages for next depth [num_rows * num_nested] - int num_nested_fields, // Number of nested message fields at this depth - int const* nested_field_indices, // Indices into schema for nested message fields - int* error_flag) +__global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in, + device_nested_field_descriptor const* schema, + int num_fields, + int depth_level, + repeated_field_info* repeated_info, + int num_repeated_fields, + int const* repeated_field_indices, + field_location* nested_locations, + int num_nested_fields, + int const* nested_field_indices, + int* error_flag, + int const* fn_to_rep_idx, + int fn_to_rep_size, + int const* fn_to_nested_idx, + int fn_to_nested_size) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; @@ -328,9 +331,11 @@ __global__ void count_repeated_fields_kernel( int fn = tag.field_number; int wt = tag.wire_type; - for (int i = 0; i < num_repeated_fields; i++) { - int schema_idx = repeated_field_indices[i]; - if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { + // Lookup repeated field by field_number + if (fn_to_rep_idx != nullptr && fn > 
0 && fn < fn_to_rep_size) { + int i = fn_to_rep_idx[fn]; + if (i >= 0) { + int schema_idx = repeated_field_indices[i]; if (!count_repeated_element(cur, msg_end, wt, @@ -340,26 +345,50 @@ __global__ void count_repeated_fields_kernel( return; } } - } - - // Check nested message fields at this depth (last one wins for non-repeated) - for (int i = 0; i < num_nested_fields; i++) { - int schema_idx = nested_field_indices[i]; - if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { - if (wt != WT_LEN) { - set_error_once(error_flag, ERR_WIRE_TYPE); - return; + } else { + for (int i = 0; i < num_repeated_fields; i++) { + int schema_idx = repeated_field_indices[i]; + if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { + if (!count_repeated_element(cur, + msg_end, + wt, + schema[schema_idx].wire_type, + repeated_info[row * num_repeated_fields + i], + error_flag)) { + return; + } } + } + } - uint64_t len; - int len_bytes; - if (!read_varint(cur, msg_end, len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; + // Check nested message fields at this depth + auto handle_nested = [&](int i) { + if (wt != WT_LEN) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return false; + } + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + int32_t msg_offset = static_cast(cur - bytes - start) + len_bytes; + nested_locations[row * num_nested_fields + i] = {msg_offset, static_cast(len)}; + return true; + }; + + if (fn_to_nested_idx != nullptr && fn > 0 && fn < fn_to_nested_size) { + int i = fn_to_nested_idx[fn]; + if (i >= 0) { + if (!handle_nested(i)) return; + } + } else { + for (int i = 0; i < num_nested_fields; i++) { + int schema_idx = nested_field_indices[i]; + if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { + if (!handle_nested(i)) return; } - - int32_t msg_offset = 
static_cast(cur - bytes - start) + len_bytes; - nested_locations[row * num_nested_fields + i] = {msg_offset, static_cast(len)}; } } @@ -451,7 +480,9 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co int depth_level, repeated_field_scan_desc const* scan_descs, int num_scan_fields, - int* error_flag) + int* error_flag, + int const* fn_to_desc_idx, + int fn_to_desc_size) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; @@ -485,12 +516,11 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co int fn = tag.field_number; int wt = tag.wire_type; - for (int f = 0; f < actual_fields; f++) { - if (scan_descs[f].field_number == fn) { - int target_wt = scan_descs[f].wire_type; - bool is_packed = (wt == WT_LEN && target_wt != WT_LEN); - if (is_packed || wt == target_wt) { - if (!scan_repeated_element(cur, + auto try_scan = [&](int f) -> bool { + int target_wt = scan_descs[f].wire_type; + bool is_packed = (wt == WT_LEN && target_wt != WT_LEN); + if (is_packed || wt == target_wt) { + return scan_repeated_element(cur, msg_end, bytes + start, wt, @@ -498,9 +528,20 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co static_cast(row), scan_descs[f].occurrences, write_idx[f], - error_flag)) { - return; - } + error_flag); + } + return true; + }; + + if (fn_to_desc_idx != nullptr && fn > 0 && fn < fn_to_desc_size) { + int f = fn_to_desc_idx[fn]; + if (f >= 0) { + if (!try_scan(f)) return; + } + } else { + for (int f = 0; f < actual_fields; f++) { + if (scan_descs[f].field_number == fn) { + if (!try_scan(f)) return; } } } From ce408d678544013478a2ec30490fb97c80d4c852 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 4 Mar 2026 16:06:00 +0800 Subject: [PATCH 044/107] Remove perf-results --- perf-results/RESULTS.md | 222 ------------------ perf-results/run_00_baseline.txt | 115 --------- .../run_00_baseline_many_repeated.txt | 47 
---- .../run_01_p0_batch_repeated_sync.txt | 118 ---------- perf-results/run_01_p0_many_repeated.txt | 48 ---- perf-results/run_01_p0_merge_reduce_scan.txt | 117 --------- perf-results/run_02_p3_defer_error_check.txt | 139 ----------- .../run_03_p1_combined_occurrence_scan.txt | 143 ----------- 8 files changed, 949 deletions(-) delete mode 100644 perf-results/RESULTS.md delete mode 100644 perf-results/run_00_baseline.txt delete mode 100644 perf-results/run_00_baseline_many_repeated.txt delete mode 100644 perf-results/run_01_p0_batch_repeated_sync.txt delete mode 100644 perf-results/run_01_p0_many_repeated.txt delete mode 100644 perf-results/run_01_p0_merge_reduce_scan.txt delete mode 100644 perf-results/run_02_p3_defer_error_check.txt delete mode 100644 perf-results/run_03_p1_combined_occurrence_scan.txt diff --git a/perf-results/RESULTS.md b/perf-results/RESULTS.md deleted file mode 100644 index b3a3864aa3..0000000000 --- a/perf-results/RESULTS.md +++ /dev/null @@ -1,222 +0,0 @@ -# Protobuf GPU Decoder Performance Optimization Log - -**GPU**: NVIDIA RTX 5880 Ada Generation (110 SMs, 48GB) -**Benchmark**: `PROTOBUF_DECODE_BENCH` (3 cases: Flat Scalars, Nested Message, Repeated Fields) - ---- - -## 00. 
Baseline - -**日期**: 2026-03-04 -**改动摘要**: 无改动,记录初始性能基准 -**原始输出**: `run_00_baseline.txt` - -### Flat Scalars - -| num_rows | num_fields | GPU Time (ms) | -|----------|------------|---------------| -| 10,000 | 10 | 0.693 | -| 100,000 | 10 | 1.502 | -| 500,000 | 10 | 4.385 | -| 10,000 | 50 | 2.549 | -| 100,000 | 50 | 7.030 | -| 500,000 | 50 | 26.646 | -| 10,000 | 200 | 10.033 | -| 100,000 | 200 | 28.587 | -| 500,000 | 200 | 115.022 | - -### Nested Message - -| num_rows | inner_fields | GPU Time (ms) | -|----------|--------------|---------------| -| 10,000 | 5 | 0.699 | -| 100,000 | 5 | 1.551 | -| 500,000 | 5 | 4.787 | -| 10,000 | 20 | 2.445 | -| 100,000 | 20 | 4.794 | -| 500,000 | 20 | 13.995 | -| 10,000 | 100 | 10.549 | -| 100,000 | 100 | 23.539 | -| 500,000 | 100 | 68.828 | - -### Repeated Fields - -| num_rows | avg_items | GPU Time (ms) | -|----------|-----------|---------------| -| 10,000 | 1 | 0.966 | -| 100,000 | 1 | 2.362 | -| 10,000 | 5 | 1.292 | -| 100,000 | 5 | 3.999 | -| 10,000 | 20 | 2.221 | -| 100,000 | 20 | 7.696 | - ---- - -## 01. 
P0: Merge thrust::reduce + exclusive_scan into inclusive_scan - -**日期**: 2026-03-04 -**改动摘要**: 替换每个 repeated 字段的 `thrust::reduce` (implicit sync) + `thrust::exclusive_scan` + H2D copy 为单次 `thrust::inclusive_scan` + D2H copy,消除 CUB reduce 内部开销 -**改动文件**: protobuf.cu -**原始输出**: `run_01_p0_merge_reduce_scan.txt` - -### Flat Scalars - -| num_rows | num_fields | Before (ms) | After (ms) | Speedup | -|----------|------------|-------------|------------|---------| -| 10,000 | 10 | 0.693 | 0.727 | - | -| 100,000 | 10 | 1.502 | 1.537 | - | -| 500,000 | 10 | 4.385 | 4.444 | - | -| 10,000 | 50 | 2.549 | 2.727 | - | -| 100,000 | 50 | 7.030 | 7.246 | - | -| 500,000 | 50 | 26.646 | 26.963 | - | -| 10,000 | 200 | 10.033 | 10.912 | - | -| 100,000 | 200 | 28.587 | 29.367 | - | -| 500,000 | 200 | 115.022 | 116.532 | - | - -### Nested Message - -| num_rows | inner_fields | Before (ms) | After (ms) | Speedup | -|----------|--------------|-------------|------------|---------| -| 10,000 | 5 | 0.699 | 0.750 | - | -| 100,000 | 5 | 1.551 | 1.639 | - | -| 500,000 | 5 | 4.787 | 4.810 | - | -| 10,000 | 20 | 2.445 | 2.630 | - | -| 100,000 | 20 | 4.794 | 5.106 | - | -| 500,000 | 20 | 13.995 | 14.065 | - | -| 10,000 | 100 | 10.549 | 11.389 | - | -| 100,000 | 100 | 23.539 | 24.114 | - | -| 500,000 | 100 | 68.828 | 69.461 | - | - -### Repeated Fields - -| num_rows | avg_items | Before (ms) | After (ms) | Speedup | -|----------|-----------|-------------|------------|---------| -| 10,000 | 1 | 0.966 | 0.945 | 1.02x | -| 100,000 | 1 | 2.362 | 2.309 | 1.02x | -| 10,000 | 5 | 1.292 | 1.271 | 1.02x | -| 100,000 | 5 | 3.999 | 3.967 | 1.01x | -| 10,000 | 20 | 2.221 | 2.220 | 1.00x | -| 100,000 | 20 | 7.696 | 7.729 | 1.00x | - -### Many Repeated Fields (新增 case,更接近客户场景) - -| num_rows | num_rep_fields | Before (ms) | After (ms) | Speedup | -|----------|----------------|-------------|------------|---------| -| 10,000 | 10 | 1.891 | 1.748 | **1.08x** | -| 100,000 | 10 | 5.595 | 5.443 | 1.03x | -| 10,000 | 30 | 
5.937 | 5.375 | **1.10x** | -| 100,000 | 30 | 18.947 | 18.527 | 1.02x | -| 10,000 | 50 | 9.851 | 9.056 | **1.09x** | -| 100,000 | 50 | 36.678 | 35.970 | 1.02x | - -**结论**: Flat/Nested 无变化(符合预期)。Repeated 3 字段 case 微小提升 (~1-2%)。**Many Repeated Fields case 在小行数下有 8-10% 提升**,大行数下 2-3% 提升——因为小行数时 sync 开销占比更大。对于客户的 ~98 个 repeated 字段 + ~13.5K 行/批的场景,此优化的收益在 small batch 区间(~10%)。 - -**UT**: 95/95 全过。 - ---- - -## 02. P3: Defer check_error_and_throw to end of decode - -**日期**: 2026-03-04 -**改动摘要**: 删除 4 处中间 `check_error_and_throw()` 调用(每次含 D2H copy + stream.synchronize),只保留函数末尾的一次最终 error check。`d_error` flag 在 GPU 上只写不读,中间检查完全冗余。 -**改动文件**: protobuf.cu -**原始输出**: `run_02_p3_defer_error_check.txt` - -### Flat Scalars (对比 baseline) - -| num_rows | num_fields | Baseline (ms) | P0+P3 (ms) | Speedup | -|----------|------------|---------------|------------|---------| -| 10,000 | 10 | 0.693 | 0.693 | 1.00x | -| 100,000 | 10 | 1.502 | 1.491 | 1.01x | -| 500,000 | 10 | 4.385 | 4.212 | **1.04x** | -| 10,000 | 50 | 2.549 | 2.584 | - | -| 100,000 | 50 | 7.030 | 7.038 | 1.00x | -| 500,000 | 50 | 26.646 | 26.500 | 1.01x | -| 10,000 | 200 | 10.033 | 10.245 | - | -| 100,000 | 200 | 28.587 | 28.589 | 1.00x | -| 500,000 | 200 | 115.022 | 115.005 | 1.00x | - -### Nested Message (对比 baseline) - -| num_rows | inner_fields | Baseline (ms) | P0+P3 (ms) | Speedup | -|----------|--------------|---------------|------------|---------| -| 10,000 | 5 | 0.699 | 0.683 | **1.02x** | -| 100,000 | 5 | 1.551 | 1.518 | **1.02x** | -| 500,000 | 5 | 4.787 | 4.584 | **1.04x** | -| 10,000 | 20 | 2.445 | 2.457 | 1.00x | -| 100,000 | 20 | 4.794 | 4.740 | **1.01x** | -| 500,000 | 20 | 13.995 | 13.798 | **1.01x** | -| 10,000 | 100 | 10.549 | 10.623 | 1.00x | -| 100,000 | 100 | 23.539 | 23.489 | 1.00x | -| 500,000 | 100 | 68.828 | 69.207 | 1.00x | - -### Repeated Fields (对比 baseline) - -| num_rows | avg_items | Baseline (ms) | P0+P3 (ms) | Speedup | -|----------|-----------|---------------|------------|---------| -| 
10,000 | 1 | 0.966 | 0.839 | **1.15x** | -| 100,000 | 1 | 2.362 | 2.153 | **1.10x** | -| 10,000 | 5 | 1.292 | 1.148 | **1.13x** | -| 100,000 | 5 | 3.999 | 3.803 | **1.05x** | -| 10,000 | 20 | 2.221 | 2.080 | **1.07x** | -| 100,000 | 20 | 7.696 | 7.516 | **1.02x** | - -### Many Repeated Fields (对比 baseline) - -| num_rows | num_rep_fields | Baseline (ms) | P0+P3 (ms) | Speedup | -|----------|----------------|---------------|------------|---------| -| 10,000 | 10 | 1.891 | 1.472 | **1.28x** | -| 100,000 | 10 | 5.595 | 5.011 | **1.12x** | -| 10,000 | 30 | 5.937 | 4.676 | **1.27x** | -| 100,000 | 30 | 18.947 | 17.065 | **1.11x** | -| 10,000 | 50 | 9.851 | 8.137 | **1.21x** | -| 100,000 | 50 | 36.678 | 33.944 | **1.08x** | - -**结论**: P3 效果显著! -- **Many Repeated Fields**: 小行数 21-28% 提升,大行数 8-12% 提升 -- **Repeated Fields (3 fields)**: 7-15% 提升 -- **Flat/Nested**: Flat 几乎不变(无 repeated → 只省了 scan/required 2 次 sync),Nested 小幅提升 1-4% -- 总体来看 P0+P3 组合在 repeated-heavy schema 下带来了可观的加速 - -**UT**: 95/95 全过。 - ---- - -## 03. 
P1: Combined repeated occurrence scan kernel - -**日期**: 2026-03-04 -**改动摘要**: 新增 `scan_all_repeated_occurrences_kernel`,一次扫描消息就记录所有 repeated 字段的 occurrence(替代原来 N 个 repeated 字段各自独立做一次全消息扫描)。同时将 offset 计算全部前置并用 1 次 sync 代替 N 次。 -**改动文件**: protobuf.cu, protobuf_kernels.cu, protobuf_common.cuh -**原始输出**: `run_03_p1_combined_occurrence_scan.txt` - -### Many Repeated Fields (P1 的核心目标 case) - -| num_rows | num_rep_fields | Baseline (ms) | P0+P3 (ms) | P0+P3+P1 (ms) | vs Baseline | vs P0+P3 | -|----------|----------------|---------------|------------|----------------|-------------|----------| -| 10,000 | 10 | 1.891 | 1.472 | 1.638 | **1.15x** | 0.90x | -| 100,000 | 10 | 5.595 | 5.011 | 5.579 | 1.00x | 0.90x | -| 10,000 | 30 | 5.937 | 4.676 | 4.620 | **1.29x** | **1.01x** | -| 100,000 | 30 | 18.947 | 17.065 | 16.545 | **1.15x** | **1.03x** | -| 10,000 | 50 | 9.851 | 8.137 | 7.958 | **1.24x** | **1.02x** | -| 100,000 | 50 | 36.678 | 33.944 | 27.664 | **1.33x** | **1.23x** | - -### Repeated Fields (3 fields, 对比 baseline) - -| num_rows | avg_items | Baseline (ms) | P0+P3+P1 (ms) | Speedup | -|----------|-----------|---------------|----------------|---------| -| 10,000 | 1 | 0.966 | 0.853 | **1.13x** | -| 100,000 | 1 | 2.362 | 2.282 | **1.04x** | -| 10,000 | 5 | 1.292 | 1.177 | **1.10x** | -| 10,000 | 20 | 2.221 | 2.129 | **1.04x** | - -### Flat Scalars / Nested Message - -不受影响(符合预期,无顶层 repeated 字段)。 - -**结论**: -- **Many Repeated 100K×50: 36.68ms → 27.66ms (vs baseline 1.33x)**,这是三个优化叠加的效果 -- P1 单独贡献(vs P0+P3):100K×50 快了 **23%**(33.94→27.66),100K×30 快了 3% -- P1 的增益在字段数多 + 行数多时最显著:每多一个 repeated 字段,原来多一次全消息扫描,现在都合并为一次 -- 10K rows + 10 fields 时 P1 有微小退化(offset 前置分配的额外开销 > 合并扫描的收益) - -**UT**: 95/95 全过。 diff --git a/perf-results/run_00_baseline.txt b/perf-results/run_00_baseline.txt deleted file mode 100644 index e9109aa058..0000000000 --- a/perf-results/run_00_baseline.txt +++ /dev/null @@ -1,115 +0,0 @@ -# Devices - -## [0] `NVIDIA RTX 5880 Ada Generation` -* SM Version: 890 (PTX 
Version: 860) -* Number of SMs: 110 -* SM Default Clock Rate: 2460 MHz -* Global Memory: 45660 MiB Free / 48506 MiB Total -* Global Memory Bus Peak: 960 GB/sec (384-bit DDR @10001MHz) -* Max Shared Memory: 100 KiB/SM, 48 KiB/Block -* L2 Cache Size: 98304 KiB -* Maximum Active Blocks: 24/SM -* Maximum Active Threads: 1536/SM, 1024/Block -* Available Registers: 65536/SM, 65536/Block -* ECC Enabled: No - -# Log - -``` -Run: [1/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=10] -Pass: Cold: 0.693088ms GPU, 0.696810ms CPU, 0.69s total GPU, 0.80s total wall, 992x -Run: [2/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=10] -Pass: Cold: 1.501602ms GPU, 1.505321ms CPU, 0.50s total GPU, 0.53s total wall, 333x -Run: [3/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=10] -Pass: Cold: 4.384572ms GPU, 4.388343ms CPU, 11.93s total GPU, 12.19s total wall, 2720x -Run: [4/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=50] -Pass: Cold: 2.548858ms GPU, 2.552612ms CPU, 5.91s total GPU, 6.15s total wall, 2320x -Run: [5/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=50] -Pass: Cold: 7.030410ms GPU, 7.034253ms CPU, 9.96s total GPU, 10.08s total wall, 1416x -Run: [6/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=50] -Pass: Cold: 26.646053ms GPU, 26.650230ms CPU, 0.51s total GPU, 0.51s total wall, 19x -Run: [7/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=200] -Pass: Cold: 10.032932ms GPU, 10.036624ms CPU, 0.50s total GPU, 0.51s total wall, 50x -Run: [8/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=200] -Pass: Cold: 28.587383ms GPU, 28.591340ms CPU, 0.51s total GPU, 0.52s total wall, 18x -Run: [9/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=200] -Pass: Cold: 115.021551ms GPU, 115.026983ms CPU, 1.27s total GPU, 1.27s total wall, 11x -Run: [10/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=5] -Pass: Cold: 0.698673ms GPU, 0.702294ms 
CPU, 0.50s total GPU, 0.58s total wall, 716x -Run: [11/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=5] -Pass: Cold: 1.550901ms GPU, 1.554820ms CPU, 1.39s total GPU, 1.48s total wall, 896x -Run: [12/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=5] -Pass: Cold: 4.786635ms GPU, 4.790429ms CPU, 13.17s total GPU, 13.44s total wall, 2752x -Run: [13/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=20] -Pass: Cold: 2.445035ms GPU, 2.448843ms CPU, 0.50s total GPU, 0.52s total wall, 205x -Run: [14/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=20] -Pass: Cold: 4.794299ms GPU, 4.798064ms CPU, 0.50s total GPU, 0.51s total wall, 105x -Run: [15/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=20] -Pass: Cold: 13.994625ms GPU, 13.998604ms CPU, 6.38s total GPU, 6.41s total wall, 456x -Run: [16/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=100] -Pass: Cold: 10.549263ms GPU, 10.553169ms CPU, 11.31s total GPU, 11.42s total wall, 1072x -Run: [17/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=100] -Warn: Current measurement timed out (15.00s) while over noise threshold (7.07% > 0.50%) -Pass: Cold: 23.539355ms GPU, 23.543439ms CPU, 14.95s total GPU, 15.00s total wall, 635x -Run: [18/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=100] -Pass: Cold: 68.827659ms GPU, 68.833052ms CPU, 0.76s total GPU, 0.76s total wall, 11x -Run: [19/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=1] -Pass: Cold: 0.965625ms GPU, 0.969249ms CPU, 1.15s total GPU, 1.29s total wall, 1196x -Run: [20/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=1] -Pass: Cold: 2.362006ms GPU, 2.365883ms CPU, 2.76s total GPU, 2.88s total wall, 1168x -Run: [21/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=5] -Pass: Cold: 1.291744ms GPU, 1.295457ms CPU, 1.47s total GPU, 1.59s total wall, 1141x -Run: [22/24] Protobuf Repeated 
Fields [Device=0 num_rows=100000 avg_items=5] -Pass: Cold: 3.999376ms GPU, 4.003021ms CPU, 8.96s total GPU, 9.18s total wall, 2240x -Run: [23/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=20] -Pass: Cold: 2.220686ms GPU, 2.224453ms CPU, 0.50s total GPU, 0.52s total wall, 226x -Run: [24/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=20] -Pass: Cold: 7.696091ms GPU, 7.699906ms CPU, 0.50s total GPU, 0.51s total wall, 65x -``` - -# Benchmark Results - -## Protobuf Flat Scalars - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | num_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|------------|---------|------------|-------|------------|-------|--------| -| 10000 | 10 | 992x | 696.810 us | 0.51% | 693.088 us | 0.51% | 10000 | -| 100000 | 10 | 333x | 1.505 ms | 0.38% | 1.502 ms | 0.38% | 100000 | -| 500000 | 10 | 2720x | 4.388 ms | 0.57% | 4.385 ms | 0.57% | 500000 | -| 10000 | 50 | 2320x | 2.553 ms | 0.80% | 2.549 ms | 0.80% | 10000 | -| 100000 | 50 | 1416x | 7.034 ms | 0.50% | 7.030 ms | 0.50% | 100000 | -| 500000 | 50 | 19x | 26.650 ms | 0.15% | 26.646 ms | 0.15% | 500000 | -| 10000 | 200 | 50x | 10.037 ms | 0.17% | 10.033 ms | 0.17% | 10000 | -| 100000 | 200 | 18x | 28.591 ms | 0.22% | 28.587 ms | 0.22% | 100000 | -| 500000 | 200 | 11x | 115.027 ms | 0.11% | 115.022 ms | 0.11% | 500000 | - -## Protobuf Nested Message - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | inner_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|--------------|---------|------------|-------|------------|-------|--------| -| 10000 | 5 | 716x | 702.294 us | 0.44% | 698.673 us | 0.44% | 10000 | -| 100000 | 5 | 896x | 1.555 ms | 0.82% | 1.551 ms | 0.82% | 100000 | -| 500000 | 5 | 2752x | 4.790 ms | 6.02% | 4.787 ms | 6.02% | 500000 | -| 10000 | 20 | 205x | 2.449 ms | 0.27% | 2.445 ms | 0.27% | 10000 | -| 100000 | 20 | 105x | 4.798 ms | 0.16% | 4.794 ms | 0.16% | 100000 | -| 500000 | 20 | 456x 
| 13.999 ms | 0.50% | 13.995 ms | 0.50% | 500000 | -| 10000 | 100 | 1072x | 10.553 ms | 2.57% | 10.549 ms | 2.57% | 10000 | -| 100000 | 100 | 635x | 23.543 ms | 7.07% | 23.539 ms | 7.07% | 100000 | -| 500000 | 100 | 11x | 68.833 ms | 0.33% | 68.828 ms | 0.33% | 500000 | - -## Protobuf Repeated Fields - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | avg_items | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|-----------|---------|------------|-------|------------|-------|--------| -| 10000 | 1 | 1196x | 969.249 us | 0.50% | 965.625 us | 0.50% | 10000 | -| 100000 | 1 | 1168x | 2.366 ms | 0.59% | 2.362 ms | 0.59% | 100000 | -| 10000 | 5 | 1141x | 1.295 ms | 0.50% | 1.292 ms | 0.50% | 10000 | -| 100000 | 5 | 2240x | 4.003 ms | 0.62% | 3.999 ms | 0.62% | 100000 | -| 10000 | 20 | 226x | 2.224 ms | 0.27% | 2.221 ms | 0.27% | 10000 | -| 100000 | 20 | 65x | 7.700 ms | 0.13% | 7.696 ms | 0.13% | 100000 | diff --git a/perf-results/run_00_baseline_many_repeated.txt b/perf-results/run_00_baseline_many_repeated.txt deleted file mode 100644 index a6b5e09a79..0000000000 --- a/perf-results/run_00_baseline_many_repeated.txt +++ /dev/null @@ -1,47 +0,0 @@ -# Devices - -## [0] `NVIDIA RTX 5880 Ada Generation` -* SM Version: 890 (PTX Version: 860) -* Number of SMs: 110 -* SM Default Clock Rate: 2460 MHz -* Global Memory: 45660 MiB Free / 48506 MiB Total -* Global Memory Bus Peak: 960 GB/sec (384-bit DDR @10001MHz) -* Max Shared Memory: 100 KiB/SM, 48 KiB/Block -* L2 Cache Size: 98304 KiB -* Maximum Active Blocks: 24/SM -* Maximum Active Threads: 1536/SM, 1024/Block -* Available Registers: 65536/SM, 65536/Block -* ECC Enabled: No - -# Log - -``` -Run: [1/6] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=10] -Pass: Cold: 1.891084ms GPU, 1.894789ms CPU, 1.03s total GPU, 1.09s total wall, 544x -Run: [2/6] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=10] -Pass: Cold: 5.595138ms GPU, 5.598711ms CPU, 6.90s total 
GPU, 7.03s total wall, 1234x -Run: [3/6] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=30] -Warn: Current measurement timed out (15.01s) while over noise threshold (1.14% > 0.50%) -Pass: Cold: 5.936679ms GPU, 5.940334ms CPU, 14.75s total GPU, 15.01s total wall, 2485x -Run: [4/6] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=30] -Pass: Cold: 18.947469ms GPU, 18.951296ms CPU, 4.62s total GPU, 4.64s total wall, 244x -Run: [5/6] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=50] -Pass: Cold: 9.850844ms GPU, 9.854526ms CPU, 0.50s total GPU, 0.51s total wall, 51x -Run: [6/6] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=50] -Pass: Cold: 36.678007ms GPU, 36.682032ms CPU, 0.51s total GPU, 0.51s total wall, 14x -``` - -# Benchmark Results - -## Protobuf Many Repeated Fields - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | num_rep_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|----------------|---------|-----------|-------|-----------|-------|--------| -| 10000 | 10 | 544x | 1.895 ms | 0.68% | 1.891 ms | 0.68% | 10000 | -| 100000 | 10 | 1234x | 5.599 ms | 0.50% | 5.595 ms | 0.50% | 100000 | -| 10000 | 30 | 2485x | 5.940 ms | 1.14% | 5.937 ms | 1.14% | 10000 | -| 100000 | 30 | 244x | 18.951 ms | 0.50% | 18.947 ms | 0.50% | 100000 | -| 10000 | 50 | 51x | 9.855 ms | 0.17% | 9.851 ms | 0.17% | 10000 | -| 100000 | 50 | 14x | 36.682 ms | 0.13% | 36.678 ms | 0.13% | 100000 | diff --git a/perf-results/run_01_p0_batch_repeated_sync.txt b/perf-results/run_01_p0_batch_repeated_sync.txt deleted file mode 100644 index 666d6829ff..0000000000 --- a/perf-results/run_01_p0_batch_repeated_sync.txt +++ /dev/null @@ -1,118 +0,0 @@ -# Devices - -## [0] `NVIDIA RTX 5880 Ada Generation` -* SM Version: 890 (PTX Version: 860) -* Number of SMs: 110 -* SM Default Clock Rate: 2460 MHz -* Global Memory: 45660 MiB Free / 48506 MiB Total -* Global Memory Bus Peak: 960 
GB/sec (384-bit DDR @10001MHz) -* Max Shared Memory: 100 KiB/SM, 48 KiB/Block -* L2 Cache Size: 98304 KiB -* Maximum Active Blocks: 24/SM -* Maximum Active Threads: 1536/SM, 1024/Block -* Available Registers: 65536/SM, 65536/Block -* ECC Enabled: No - -# Log - -``` -Run: [1/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=10] -Pass: Cold: 0.717111ms GPU, 0.720729ms CPU, 1.54s total GPU, 1.79s total wall, 2144x -Run: [2/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=10] -Pass: Cold: 1.529744ms GPU, 1.533401ms CPU, 0.50s total GPU, 0.53s total wall, 327x -Run: [3/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=10] -Pass: Cold: 4.439003ms GPU, 4.442652ms CPU, 6.75s total GPU, 6.89s total wall, 1520x -Run: [4/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=50] -Pass: Cold: 2.655585ms GPU, 2.659248ms CPU, 6.20s total GPU, 6.44s total wall, 2336x -Run: [5/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=50] -Warn: Current measurement timed out (15.01s) while over noise threshold (5.03% > 0.50%) -Pass: Cold: 7.199181ms GPU, 7.202834ms CPU, 14.82s total GPU, 15.01s total wall, 2058x -Run: [6/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=50] -Pass: Cold: 26.918621ms GPU, 26.922675ms CPU, 0.51s total GPU, 0.51s total wall, 19x -Run: [7/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=200] -Pass: Cold: 10.300780ms GPU, 10.304434ms CPU, 0.50s total GPU, 0.51s total wall, 49x -Run: [8/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=200] -Pass: Cold: 29.027086ms GPU, 29.031588ms CPU, 0.52s total GPU, 0.52s total wall, 18x -Run: [9/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=200] -Pass: Cold: 116.469964ms GPU, 116.475358ms CPU, 1.28s total GPU, 1.28s total wall, 11x -Run: [10/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=5] -Pass: Cold: 0.723146ms GPU, 0.726771ms CPU, 0.50s total GPU, 0.58s total wall, 692x -Run: [11/24] 
Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=5] -Pass: Cold: 1.589212ms GPU, 1.593000ms CPU, 0.50s total GPU, 0.53s total wall, 315x -Run: [12/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=5] -Pass: Cold: 4.805955ms GPU, 4.809632ms CPU, 6.38s total GPU, 6.51s total wall, 1328x -Run: [13/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=20] -Pass: Cold: 2.522655ms GPU, 2.526255ms CPU, 0.50s total GPU, 0.52s total wall, 199x -Run: [14/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=20] -Pass: Cold: 4.899035ms GPU, 4.902721ms CPU, 0.50s total GPU, 0.51s total wall, 103x -Run: [15/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=20] -Warn: Current measurement timed out (15.01s) while over noise threshold (1.47% > 0.50%) -Pass: Cold: 14.223304ms GPU, 14.227113ms CPU, 14.93s total GPU, 15.01s total wall, 1050x -Run: [16/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=100] -Pass: Cold: 11.633095ms GPU, 11.636902ms CPU, 5.96s total GPU, 6.01s total wall, 512x -Run: [17/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=100] -Warn: Current measurement timed out (15.02s) while over noise threshold (3.23% > 0.50%) -Pass: Cold: 24.105674ms GPU, 24.109725ms CPU, 14.97s total GPU, 15.02s total wall, 621x -Run: [18/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=100] -Pass: Cold: 69.203103ms GPU, 69.208290ms CPU, 0.76s total GPU, 0.76s total wall, 11x -Run: [19/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=1] -Pass: Cold: 1.898922ms GPU, 1.902667ms CPU, 0.50s total GPU, 0.53s total wall, 264x -Run: [20/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=1] -Pass: Cold: 3.314200ms GPU, 3.318216ms CPU, 2.81s total GPU, 2.90s total wall, 848x -Run: [21/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=5] -Pass: Cold: 2.205836ms GPU, 2.209719ms CPU, 0.50s total GPU, 0.52s total wall, 227x -Run: 
[22/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=5] -Pass: Cold: 5.259358ms GPU, 5.263366ms CPU, 0.50s total GPU, 0.51s total wall, 96x -Run: [23/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=20] -Pass: Cold: 3.261910ms GPU, 3.265949ms CPU, 1.90s total GPU, 1.96s total wall, 583x -Run: [24/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=20] -Warn: Current measurement timed out (15.01s) while over noise threshold (0.85% > 0.50%) -Pass: Cold: 8.872888ms GPU, 8.877022ms CPU, 14.87s total GPU, 15.01s total wall, 1676x -``` - -# Benchmark Results - -## Protobuf Flat Scalars - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | num_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|------------|---------|------------|-------|------------|-------|--------| -| 10000 | 10 | 2144x | 720.729 us | 1.61% | 717.111 us | 1.62% | 10000 | -| 100000 | 10 | 327x | 1.533 ms | 0.36% | 1.530 ms | 0.36% | 100000 | -| 500000 | 10 | 1520x | 4.443 ms | 0.71% | 4.439 ms | 0.71% | 500000 | -| 10000 | 50 | 2336x | 2.659 ms | 2.30% | 2.656 ms | 2.30% | 10000 | -| 100000 | 50 | 2058x | 7.203 ms | 5.03% | 7.199 ms | 5.03% | 100000 | -| 500000 | 50 | 19x | 26.923 ms | 0.09% | 26.919 ms | 0.09% | 500000 | -| 10000 | 200 | 49x | 10.304 ms | 0.29% | 10.301 ms | 0.29% | 10000 | -| 100000 | 200 | 18x | 29.032 ms | 0.13% | 29.027 ms | 0.13% | 100000 | -| 500000 | 200 | 11x | 116.475 ms | 0.37% | 116.470 ms | 0.37% | 500000 | - -## Protobuf Nested Message - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | inner_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|--------------|---------|------------|-------|------------|-------|--------| -| 10000 | 5 | 692x | 726.771 us | 0.50% | 723.146 us | 0.50% | 10000 | -| 100000 | 5 | 315x | 1.593 ms | 0.34% | 1.589 ms | 0.34% | 100000 | -| 500000 | 5 | 1328x | 4.810 ms | 0.57% | 4.806 ms | 0.57% | 500000 | -| 10000 | 20 | 199x | 2.526 ms | 0.27% | 
2.523 ms | 0.27% | 10000 | -| 100000 | 20 | 103x | 4.903 ms | 0.29% | 4.899 ms | 0.29% | 100000 | -| 500000 | 20 | 1050x | 14.227 ms | 1.46% | 14.223 ms | 1.47% | 500000 | -| 10000 | 100 | 512x | 11.637 ms | 5.12% | 11.633 ms | 5.12% | 10000 | -| 100000 | 100 | 621x | 24.110 ms | 3.23% | 24.106 ms | 3.23% | 100000 | -| 500000 | 100 | 11x | 69.208 ms | 0.42% | 69.203 ms | 0.42% | 500000 | - -## Protobuf Repeated Fields - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | avg_items | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|-----------|---------|----------|-------|----------|-------|--------| -| 10000 | 1 | 264x | 1.903 ms | 0.48% | 1.899 ms | 0.48% | 10000 | -| 100000 | 1 | 848x | 3.318 ms | 0.71% | 3.314 ms | 0.71% | 100000 | -| 10000 | 5 | 227x | 2.210 ms | 0.29% | 2.206 ms | 0.29% | 10000 | -| 100000 | 5 | 96x | 5.263 ms | 0.19% | 5.259 ms | 0.19% | 100000 | -| 10000 | 20 | 583x | 3.266 ms | 0.50% | 3.262 ms | 0.50% | 10000 | -| 100000 | 20 | 1676x | 8.877 ms | 0.85% | 8.873 ms | 0.85% | 100000 | diff --git a/perf-results/run_01_p0_many_repeated.txt b/perf-results/run_01_p0_many_repeated.txt deleted file mode 100644 index ce41520b18..0000000000 --- a/perf-results/run_01_p0_many_repeated.txt +++ /dev/null @@ -1,48 +0,0 @@ -# Devices - -## [0] `NVIDIA RTX 5880 Ada Generation` -* SM Version: 890 (PTX Version: 860) -* Number of SMs: 110 -* SM Default Clock Rate: 2460 MHz -* Global Memory: 45660 MiB Free / 48506 MiB Total -* Global Memory Bus Peak: 960 GB/sec (384-bit DDR @10001MHz) -* Max Shared Memory: 100 KiB/SM, 48 KiB/Block -* L2 Cache Size: 98304 KiB -* Maximum Active Blocks: 24/SM -* Maximum Active Threads: 1536/SM, 1024/Block -* Available Registers: 65536/SM, 65536/Block -* ECC Enabled: No - -# Log - -``` -Run: [1/6] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=10] -Pass: Cold: 1.748073ms GPU, 1.751753ms CPU, 1.62s total GPU, 1.73s total wall, 928x -Run: [2/6] Protobuf Many Repeated Fields [Device=0 
num_rows=100000 num_rep_fields=10] -Warn: Current measurement timed out (15.00s) while over noise threshold (4.72% > 0.50%) -Pass: Cold: 5.443364ms GPU, 5.446938ms CPU, 14.72s total GPU, 15.00s total wall, 2705x -Run: [3/6] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=30] -Pass: Cold: 5.374920ms GPU, 5.378463ms CPU, 4.13s total GPU, 4.21s total wall, 768x -Run: [4/6] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=30] -Warn: Current measurement timed out (15.02s) while over noise threshold (0.65% > 0.50%) -Pass: Cold: 18.527305ms GPU, 18.531079ms CPU, 14.95s total GPU, 15.02s total wall, 807x -Run: [5/6] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=50] -Pass: Cold: 9.055705ms GPU, 9.059334ms CPU, 0.51s total GPU, 0.51s total wall, 56x -Run: [6/6] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=50] -Pass: Cold: 35.969858ms GPU, 35.973699ms CPU, 0.50s total GPU, 0.50s total wall, 14x -``` - -# Benchmark Results - -## Protobuf Many Repeated Fields - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | num_rep_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|----------------|---------|-----------|-------|-----------|-------|--------| -| 10000 | 10 | 928x | 1.752 ms | 0.80% | 1.748 ms | 0.80% | 10000 | -| 100000 | 10 | 2705x | 5.447 ms | 4.72% | 5.443 ms | 4.72% | 100000 | -| 10000 | 30 | 768x | 5.378 ms | 3.05% | 5.375 ms | 3.05% | 10000 | -| 100000 | 30 | 807x | 18.531 ms | 0.65% | 18.527 ms | 0.65% | 100000 | -| 10000 | 50 | 56x | 9.059 ms | 0.25% | 9.056 ms | 0.25% | 10000 | -| 100000 | 50 | 14x | 35.974 ms | 0.19% | 35.970 ms | 0.19% | 100000 | diff --git a/perf-results/run_01_p0_merge_reduce_scan.txt b/perf-results/run_01_p0_merge_reduce_scan.txt deleted file mode 100644 index 9f4d4e8c93..0000000000 --- a/perf-results/run_01_p0_merge_reduce_scan.txt +++ /dev/null @@ -1,117 +0,0 @@ -# Devices - -## [0] `NVIDIA RTX 5880 Ada Generation` -* SM 
Version: 890 (PTX Version: 860) -* Number of SMs: 110 -* SM Default Clock Rate: 2460 MHz -* Global Memory: 45660 MiB Free / 48506 MiB Total -* Global Memory Bus Peak: 960 GB/sec (384-bit DDR @10001MHz) -* Max Shared Memory: 100 KiB/SM, 48 KiB/Block -* L2 Cache Size: 98304 KiB -* Maximum Active Blocks: 24/SM -* Maximum Active Threads: 1536/SM, 1024/Block -* Available Registers: 65536/SM, 65536/Block -* ECC Enabled: No - -# Log - -``` -Run: [1/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=10] -Pass: Cold: 0.726904ms GPU, 0.730745ms CPU, 1.80s total GPU, 2.10s total wall, 2480x -Run: [2/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=10] -Pass: Cold: 1.536747ms GPU, 1.540522ms CPU, 1.18s total GPU, 1.26s total wall, 768x -Run: [3/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=10] -Pass: Cold: 4.443912ms GPU, 4.447723ms CPU, 14.43s total GPU, 14.76s total wall, 3248x -Run: [4/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=50] -Pass: Cold: 2.726921ms GPU, 2.730759ms CPU, 5.15s total GPU, 5.34s total wall, 1888x -Run: [5/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=50] -Warn: Current measurement timed out (15.01s) while over noise threshold (0.90% > 0.50%) -Pass: Cold: 7.246037ms GPU, 7.249840ms CPU, 14.81s total GPU, 15.01s total wall, 2044x -Run: [6/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=50] -Pass: Cold: 26.963141ms GPU, 26.967804ms CPU, 0.51s total GPU, 0.51s total wall, 19x -Run: [7/24] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=200] -Pass: Cold: 10.912177ms GPU, 10.916118ms CPU, 9.08s total GPU, 9.17s total wall, 832x -Run: [8/24] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=200] -Warn: Current measurement timed out (15.02s) while over noise threshold (0.86% > 0.50%) -Pass: Cold: 29.366643ms GPU, 29.370824ms CPU, 14.98s total GPU, 15.02s total wall, 510x -Run: [9/24] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=200] 
-Pass: Cold: 116.532303ms GPU, 116.538900ms CPU, 1.28s total GPU, 1.28s total wall, 11x -Run: [10/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=5] -Pass: Cold: 0.749640ms GPU, 0.753178ms CPU, 1.09s total GPU, 1.26s total wall, 1456x -Run: [11/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=5] -Pass: Cold: 1.638890ms GPU, 1.642742ms CPU, 4.20s total GPU, 4.46s total wall, 2560x -Run: [12/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=5] -Pass: Cold: 4.810134ms GPU, 4.814028ms CPU, 12.24s total GPU, 12.49s total wall, 2544x -Run: [13/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=20] -Pass: Cold: 2.629933ms GPU, 2.633856ms CPU, 2.99s total GPU, 3.11s total wall, 1136x -Run: [14/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=20] -Pass: Cold: 5.105757ms GPU, 5.109912ms CPU, 4.74s total GPU, 4.83s total wall, 928x -Run: [15/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=20] -Pass: Cold: 14.065493ms GPU, 14.069656ms CPU, 6.64s total GPU, 6.67s total wall, 472x -Run: [16/24] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=100] -Pass: Cold: 11.388964ms GPU, 11.393028ms CPU, 7.29s total GPU, 7.35s total wall, 640x -Run: [17/24] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=100] -Warn: Current measurement timed out (15.00s) while over noise threshold (1.33% > 0.50%) -Pass: Cold: 24.113780ms GPU, 24.118162ms CPU, 14.95s total GPU, 15.00s total wall, 620x -Run: [18/24] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=100] -Pass: Cold: 69.461100ms GPU, 69.466652ms CPU, 0.76s total GPU, 0.76s total wall, 11x -Run: [19/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=1] -Pass: Cold: 0.944527ms GPU, 0.948112ms CPU, 0.50s total GPU, 0.56s total wall, 530x -Run: [20/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=1] -Pass: Cold: 2.308945ms GPU, 2.312709ms CPU, 0.50s total GPU, 0.52s total 
wall, 217x -Run: [21/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=5] -Pass: Cold: 1.271388ms GPU, 1.275154ms CPU, 2.85s total GPU, 3.09s total wall, 2240x -Run: [22/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=5] -Pass: Cold: 3.967131ms GPU, 3.970874ms CPU, 7.30s total GPU, 7.49s total wall, 1840x -Run: [23/24] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=20] -Pass: Cold: 2.220037ms GPU, 2.223704ms CPU, 0.50s total GPU, 0.52s total wall, 226x -Run: [24/24] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=20] -Pass: Cold: 7.729027ms GPU, 7.733813ms CPU, 0.50s total GPU, 0.51s total wall, 65x -``` - -# Benchmark Results - -## Protobuf Flat Scalars - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | num_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|------------|---------|------------|-------|------------|-------|--------| -| 10000 | 10 | 2480x | 730.745 us | 2.70% | 726.904 us | 2.72% | 10000 | -| 100000 | 10 | 768x | 1.541 ms | 1.42% | 1.537 ms | 1.43% | 100000 | -| 500000 | 10 | 3248x | 4.448 ms | 6.56% | 4.444 ms | 6.56% | 500000 | -| 10000 | 50 | 1888x | 2.731 ms | 1.65% | 2.727 ms | 1.65% | 10000 | -| 100000 | 50 | 2044x | 7.250 ms | 0.90% | 7.246 ms | 0.90% | 100000 | -| 500000 | 50 | 19x | 26.968 ms | 0.20% | 26.963 ms | 0.20% | 500000 | -| 10000 | 200 | 832x | 10.916 ms | 2.51% | 10.912 ms | 2.51% | 10000 | -| 100000 | 200 | 510x | 29.371 ms | 0.86% | 29.367 ms | 0.86% | 100000 | -| 500000 | 200 | 11x | 116.539 ms | 0.14% | 116.532 ms | 0.14% | 500000 | - -## Protobuf Nested Message - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | inner_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|--------------|---------|------------|-------|------------|-------|--------| -| 10000 | 5 | 1456x | 753.178 us | 1.51% | 749.640 us | 1.52% | 10000 | -| 100000 | 5 | 2560x | 1.643 ms | 0.90% | 1.639 ms | 0.89% | 100000 | -| 500000 | 5 | 2544x 
| 4.814 ms | 1.55% | 4.810 ms | 1.55% | 500000 | -| 10000 | 20 | 1136x | 2.634 ms | 1.91% | 2.630 ms | 1.91% | 10000 | -| 100000 | 20 | 928x | 5.110 ms | 1.47% | 5.106 ms | 1.47% | 100000 | -| 500000 | 20 | 472x | 14.070 ms | 0.50% | 14.065 ms | 0.50% | 500000 | -| 10000 | 100 | 640x | 11.393 ms | 4.83% | 11.389 ms | 4.83% | 10000 | -| 100000 | 100 | 620x | 24.118 ms | 1.33% | 24.114 ms | 1.33% | 100000 | -| 500000 | 100 | 11x | 69.467 ms | 0.14% | 69.461 ms | 0.14% | 500000 | - -## Protobuf Repeated Fields - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | avg_items | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|-----------|---------|------------|-------|------------|-------|--------| -| 10000 | 1 | 530x | 948.112 us | 0.35% | 944.527 us | 0.35% | 10000 | -| 100000 | 1 | 217x | 2.313 ms | 0.27% | 2.309 ms | 0.27% | 100000 | -| 10000 | 5 | 2240x | 1.275 ms | 1.13% | 1.271 ms | 1.13% | 10000 | -| 100000 | 5 | 1840x | 3.971 ms | 0.57% | 3.967 ms | 0.57% | 100000 | -| 10000 | 20 | 226x | 2.224 ms | 0.32% | 2.220 ms | 0.32% | 10000 | -| 100000 | 20 | 65x | 7.734 ms | 0.13% | 7.729 ms | 0.13% | 100000 | diff --git a/perf-results/run_02_p3_defer_error_check.txt b/perf-results/run_02_p3_defer_error_check.txt deleted file mode 100644 index dd47a5ba1b..0000000000 --- a/perf-results/run_02_p3_defer_error_check.txt +++ /dev/null @@ -1,139 +0,0 @@ -# Devices - -## [0] `NVIDIA RTX 5880 Ada Generation` -* SM Version: 890 (PTX Version: 860) -* Number of SMs: 110 -* SM Default Clock Rate: 2460 MHz -* Global Memory: 45660 MiB Free / 48506 MiB Total -* Global Memory Bus Peak: 960 GB/sec (384-bit DDR @10001MHz) -* Max Shared Memory: 100 KiB/SM, 48 KiB/Block -* L2 Cache Size: 98304 KiB -* Maximum Active Blocks: 24/SM -* Maximum Active Threads: 1536/SM, 1024/Block -* Available Registers: 65536/SM, 65536/Block -* ECC Enabled: No - -# Log - -``` -Run: [1/30] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=10] -Pass: Cold: 0.692966ms GPU, 0.696617ms 
CPU, 1.88s total GPU, 2.20s total wall, 2720x -Run: [2/30] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=10] -Pass: Cold: 1.491170ms GPU, 1.494817ms CPU, 4.75s total GPU, 5.07s total wall, 3184x -Run: [3/30] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=10] -Pass: Cold: 4.211614ms GPU, 4.215283ms CPU, 11.32s total GPU, 11.58s total wall, 2688x -Run: [4/30] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=50] -Pass: Cold: 2.584040ms GPU, 2.587704ms CPU, 8.14s total GPU, 8.47s total wall, 3152x -Run: [5/30] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=50] -Pass: Cold: 7.038277ms GPU, 7.041922ms CPU, 5.58s total GPU, 5.65s total wall, 793x -Run: [6/30] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=50] -Pass: Cold: 26.499879ms GPU, 26.504024ms CPU, 0.50s total GPU, 0.50s total wall, 19x -Run: [7/30] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=200] -Pass: Cold: 10.245305ms GPU, 10.249014ms CPU, 0.50s total GPU, 0.51s total wall, 49x -Run: [8/30] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=200] -Pass: Cold: 28.589312ms GPU, 28.593149ms CPU, 7.69s total GPU, 7.71s total wall, 269x -Run: [9/30] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=200] -Pass: Cold: 115.004578ms GPU, 115.010221ms CPU, 1.27s total GPU, 1.27s total wall, 11x -Run: [10/30] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=5] -Pass: Cold: 0.682600ms GPU, 0.686183ms CPU, 0.67s total GPU, 0.78s total wall, 976x -Run: [11/30] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=5] -Pass: Cold: 1.518160ms GPU, 1.522011ms CPU, 4.18s total GPU, 4.45s total wall, 2752x -Run: [12/30] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=5] -Pass: Cold: 4.583684ms GPU, 4.587425ms CPU, 9.11s total GPU, 9.30s total wall, 1987x -Run: [13/30] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=20] -Pass: Cold: 2.456863ms GPU, 2.460561ms CPU, 0.50s total GPU, 0.52s total 
wall, 204x -Run: [14/30] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=20] -Pass: Cold: 4.740401ms GPU, 4.744170ms CPU, 0.50s total GPU, 0.51s total wall, 106x -Run: [15/30] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=20] -Pass: Cold: 13.798471ms GPU, 13.802428ms CPU, 14.52s total GPU, 14.59s total wall, 1052x -Run: [16/30] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=100] -Pass: Cold: 10.622611ms GPU, 10.626416ms CPU, 0.51s total GPU, 0.51s total wall, 48x -Run: [17/30] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=100] -Pass: Cold: 23.489448ms GPU, 23.493497ms CPU, 13.15s total GPU, 13.20s total wall, 560x -Run: [18/30] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=100] -Pass: Cold: 69.207141ms GPU, 69.212410ms CPU, 5.74s total GPU, 5.75s total wall, 83x -Run: [19/30] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=1] -Pass: Cold: 0.839203ms GPU, 0.842775ms CPU, 0.50s total GPU, 0.57s total wall, 596x -Run: [20/30] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=1] -Pass: Cold: 2.152807ms GPU, 2.156525ms CPU, 0.50s total GPU, 0.52s total wall, 233x -Run: [21/30] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=5] -Pass: Cold: 1.148301ms GPU, 1.151898ms CPU, 0.50s total GPU, 0.54s total wall, 436x -Run: [22/30] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=5] -Pass: Cold: 3.803068ms GPU, 3.806759ms CPU, 7.05s total GPU, 7.23s total wall, 1853x -Run: [23/30] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=20] -Pass: Cold: 2.080413ms GPU, 2.084091ms CPU, 0.50s total GPU, 0.53s total wall, 241x -Run: [24/30] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=20] -Pass: Cold: 7.516108ms GPU, 7.520008ms CPU, 0.50s total GPU, 0.51s total wall, 67x -Run: [25/30] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=10] -Pass: Cold: 1.471735ms GPU, 1.475218ms CPU, 0.50s total GPU, 0.54s total wall, 
340x -Run: [26/30] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=10] -Pass: Cold: 5.010848ms GPU, 5.014496ms CPU, 0.50s total GPU, 0.51s total wall, 100x -Run: [27/30] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=30] -Pass: Cold: 4.675919ms GPU, 4.679577ms CPU, 0.50s total GPU, 0.51s total wall, 107x -Run: [28/30] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=30] -Pass: Cold: 17.064845ms GPU, 17.068522ms CPU, 0.51s total GPU, 0.51s total wall, 30x -Run: [29/30] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=50] -Pass: Cold: 8.136901ms GPU, 8.140626ms CPU, 0.50s total GPU, 0.51s total wall, 62x -Run: [30/30] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=50] -Pass: Cold: 33.943523ms GPU, 33.947271ms CPU, 0.51s total GPU, 0.51s total wall, 15x -``` - -# Benchmark Results - -## Protobuf Flat Scalars - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | num_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|------------|---------|------------|-------|------------|-------|--------| -| 10000 | 10 | 2720x | 696.617 us | 1.30% | 692.966 us | 1.31% | 10000 | -| 100000 | 10 | 3184x | 1.495 ms | 5.80% | 1.491 ms | 5.81% | 100000 | -| 500000 | 10 | 2688x | 4.215 ms | 0.55% | 4.212 ms | 0.55% | 500000 | -| 10000 | 50 | 3152x | 2.588 ms | 1.31% | 2.584 ms | 1.31% | 10000 | -| 100000 | 50 | 793x | 7.042 ms | 0.50% | 7.038 ms | 0.50% | 100000 | -| 500000 | 50 | 19x | 26.504 ms | 0.07% | 26.500 ms | 0.07% | 500000 | -| 10000 | 200 | 49x | 10.249 ms | 0.34% | 10.245 ms | 0.34% | 10000 | -| 100000 | 200 | 269x | 28.593 ms | 0.50% | 28.589 ms | 0.50% | 100000 | -| 500000 | 200 | 11x | 115.010 ms | 0.32% | 115.005 ms | 0.32% | 500000 | - -## Protobuf Nested Message - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | inner_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | 
-|----------|--------------|---------|------------|-------|------------|-------|--------| -| 10000 | 5 | 976x | 686.183 us | 0.50% | 682.600 us | 0.51% | 10000 | -| 100000 | 5 | 2752x | 1.522 ms | 0.86% | 1.518 ms | 0.86% | 100000 | -| 500000 | 5 | 1987x | 4.587 ms | 0.50% | 4.584 ms | 0.50% | 500000 | -| 10000 | 20 | 204x | 2.461 ms | 0.24% | 2.457 ms | 0.24% | 10000 | -| 100000 | 20 | 106x | 4.744 ms | 0.15% | 4.740 ms | 0.15% | 100000 | -| 500000 | 20 | 1052x | 13.802 ms | 0.50% | 13.798 ms | 0.50% | 500000 | -| 10000 | 100 | 48x | 10.626 ms | 0.27% | 10.623 ms | 0.27% | 10000 | -| 100000 | 100 | 560x | 23.493 ms | 1.13% | 23.489 ms | 1.13% | 100000 | -| 500000 | 100 | 83x | 69.212 ms | 0.50% | 69.207 ms | 0.50% | 500000 | - -## Protobuf Repeated Fields - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | avg_items | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|-----------|---------|------------|-------|------------|-------|--------| -| 10000 | 1 | 596x | 842.775 us | 0.35% | 839.203 us | 0.35% | 10000 | -| 100000 | 1 | 233x | 2.157 ms | 0.27% | 2.153 ms | 0.27% | 100000 | -| 10000 | 5 | 436x | 1.152 ms | 0.47% | 1.148 ms | 0.47% | 10000 | -| 100000 | 5 | 1853x | 3.807 ms | 0.50% | 3.803 ms | 0.50% | 100000 | -| 10000 | 20 | 241x | 2.084 ms | 0.23% | 2.080 ms | 0.24% | 10000 | -| 100000 | 20 | 67x | 7.520 ms | 0.13% | 7.516 ms | 0.13% | 100000 | - -## Protobuf Many Repeated Fields - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | num_rep_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|----------------|---------|-----------|-------|-----------|-------|--------| -| 10000 | 10 | 340x | 1.475 ms | 0.27% | 1.472 ms | 0.27% | 10000 | -| 100000 | 10 | 100x | 5.014 ms | 0.49% | 5.011 ms | 0.49% | 100000 | -| 10000 | 30 | 107x | 4.680 ms | 0.14% | 4.676 ms | 0.14% | 10000 | -| 100000 | 30 | 30x | 17.069 ms | 0.13% | 17.065 ms | 0.13% | 100000 | -| 10000 | 50 | 62x | 8.141 ms | 0.17% | 8.137 ms | 0.17% | 
10000 | -| 100000 | 50 | 15x | 33.947 ms | 0.09% | 33.944 ms | 0.09% | 100000 | diff --git a/perf-results/run_03_p1_combined_occurrence_scan.txt b/perf-results/run_03_p1_combined_occurrence_scan.txt deleted file mode 100644 index 85e244b8fe..0000000000 --- a/perf-results/run_03_p1_combined_occurrence_scan.txt +++ /dev/null @@ -1,143 +0,0 @@ -# Devices - -## [0] `NVIDIA RTX 5880 Ada Generation` -* SM Version: 890 (PTX Version: 860) -* Number of SMs: 110 -* SM Default Clock Rate: 2460 MHz -* Global Memory: 45660 MiB Free / 48506 MiB Total -* Global Memory Bus Peak: 960 GB/sec (384-bit DDR @10001MHz) -* Max Shared Memory: 100 KiB/SM, 48 KiB/Block -* L2 Cache Size: 98304 KiB -* Maximum Active Blocks: 24/SM -* Maximum Active Threads: 1536/SM, 1024/Block -* Available Registers: 65536/SM, 65536/Block -* ECC Enabled: No - -# Log - -``` -Run: [1/30] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=10] -Pass: Cold: 0.686546ms GPU, 0.690229ms CPU, 1.96s total GPU, 2.29s total wall, 2848x -Run: [2/30] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=10] -Pass: Cold: 1.501841ms GPU, 1.505566ms CPU, 0.50s total GPU, 0.53s total wall, 333x -Run: [3/30] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=10] -Pass: Cold: 4.243665ms GPU, 4.247329ms CPU, 11.81s total GPU, 12.09s total wall, 2784x -Run: [4/30] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=50] -Pass: Cold: 2.617960ms GPU, 2.621655ms CPU, 7.04s total GPU, 7.31s total wall, 2688x -Run: [5/30] Protobuf Flat Scalars [Device=0 num_rows=100000 num_fields=50] -Pass: Cold: 7.109102ms GPU, 7.112786ms CPU, 9.78s total GPU, 9.91s total wall, 1376x -Run: [6/30] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=50] -Pass: Cold: 26.655070ms GPU, 26.659158ms CPU, 0.51s total GPU, 0.51s total wall, 19x -Run: [7/30] Protobuf Flat Scalars [Device=0 num_rows=10000 num_fields=200] -Pass: Cold: 10.439957ms GPU, 10.443692ms CPU, 11.36s total GPU, 11.47s total wall, 1088x -Run: [8/30] Protobuf 
Flat Scalars [Device=0 num_rows=100000 num_fields=200] -Warn: Current measurement timed out (15.02s) while over noise threshold (1.48% > 0.50%) -Pass: Cold: 29.198385ms GPU, 29.202193ms CPU, 14.98s total GPU, 15.02s total wall, 513x -Run: [9/30] Protobuf Flat Scalars [Device=0 num_rows=500000 num_fields=200] -Pass: Cold: 115.994634ms GPU, 115.999994ms CPU, 7.42s total GPU, 7.43s total wall, 64x -Run: [10/30] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=5] -Pass: Cold: 0.682638ms GPU, 0.686217ms CPU, 1.07s total GPU, 1.25s total wall, 1568x -Run: [11/30] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=5] -Pass: Cold: 1.538363ms GPU, 1.542351ms CPU, 4.48s total GPU, 4.77s total wall, 2912x -Run: [12/30] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=5] -Pass: Cold: 4.651629ms GPU, 4.655404ms CPU, 13.02s total GPU, 13.30s total wall, 2800x -Run: [13/30] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=20] -Pass: Cold: 2.458447ms GPU, 2.462209ms CPU, 0.50s total GPU, 0.52s total wall, 204x -Run: [14/30] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=20] -Pass: Cold: 4.817362ms GPU, 4.821159ms CPU, 0.50s total GPU, 0.51s total wall, 104x -Run: [15/30] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=20] -Warn: Current measurement timed out (15.00s) while over noise threshold (3.64% > 0.50%) -Pass: Cold: 13.952591ms GPU, 13.956618ms CPU, 14.93s total GPU, 15.00s total wall, 1070x -Run: [16/30] Protobuf Nested Message [Device=0 num_rows=10000 inner_fields=100] -Pass: Cold: 11.263608ms GPU, 11.267472ms CPU, 5.95s total GPU, 6.00s total wall, 528x -Run: [17/30] Protobuf Nested Message [Device=0 num_rows=100000 inner_fields=100] -Pass: Cold: 24.756087ms GPU, 24.760204ms CPU, 13.86s total GPU, 13.91s total wall, 560x -Run: [18/30] Protobuf Nested Message [Device=0 num_rows=500000 inner_fields=100] -Warn: Current measurement timed out (15.00s) while over noise threshold (0.91% > 0.50%) 
-Pass: Cold: 70.025400ms GPU, 70.030674ms CPU, 14.99s total GPU, 15.00s total wall, 214x -Run: [19/30] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=1] -Pass: Cold: 0.852793ms GPU, 0.856320ms CPU, 0.50s total GPU, 0.57s total wall, 587x -Run: [20/30] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=1] -Pass: Cold: 2.281557ms GPU, 2.285314ms CPU, 0.50s total GPU, 0.52s total wall, 220x -Run: [21/30] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=5] -Pass: Cold: 1.177153ms GPU, 1.180749ms CPU, 1.45s total GPU, 1.58s total wall, 1232x -Run: [22/30] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=5] -Pass: Cold: 4.102615ms GPU, 4.106303ms CPU, 9.32s total GPU, 9.55s total wall, 2272x -Run: [23/30] Protobuf Repeated Fields [Device=0 num_rows=10000 avg_items=20] -Pass: Cold: 2.129240ms GPU, 2.132977ms CPU, 0.50s total GPU, 0.52s total wall, 235x -Run: [24/30] Protobuf Repeated Fields [Device=0 num_rows=100000 avg_items=20] -Pass: Cold: 7.738479ms GPU, 7.742156ms CPU, 0.50s total GPU, 0.51s total wall, 65x -Run: [25/30] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=10] -Pass: Cold: 1.637901ms GPU, 1.641626ms CPU, 0.50s total GPU, 0.53s total wall, 306x -Run: [26/30] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=10] -Pass: Cold: 5.579029ms GPU, 5.582809ms CPU, 4.64s total GPU, 4.73s total wall, 832x -Run: [27/30] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=30] -Pass: Cold: 4.620287ms GPU, 4.623976ms CPU, 0.50s total GPU, 0.51s total wall, 109x -Run: [28/30] Protobuf Many Repeated Fields [Device=0 num_rows=100000 num_rep_fields=30] -Pass: Cold: 16.545099ms GPU, 16.548836ms CPU, 12.44s total GPU, 12.51s total wall, 752x -Run: [29/30] Protobuf Many Repeated Fields [Device=0 num_rows=10000 num_rep_fields=50] -Pass: Cold: 7.957812ms GPU, 7.961486ms CPU, 4.46s total GPU, 4.51s total wall, 560x -Run: [30/30] Protobuf Many Repeated Fields [Device=0 
num_rows=100000 num_rep_fields=50] -Warn: Current measurement timed out (15.01s) while over noise threshold (0.69% > 0.50%) -Pass: Cold: 27.663957ms GPU, 27.667840ms CPU, 14.97s total GPU, 15.01s total wall, 541x -``` - -# Benchmark Results - -## Protobuf Flat Scalars - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | num_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|------------|---------|------------|-------|------------|-------|--------| -| 10000 | 10 | 2848x | 690.229 us | 1.41% | 686.546 us | 1.41% | 10000 | -| 100000 | 10 | 333x | 1.506 ms | 0.35% | 1.502 ms | 0.35% | 100000 | -| 500000 | 10 | 2784x | 4.247 ms | 0.73% | 4.244 ms | 0.73% | 500000 | -| 10000 | 50 | 2688x | 2.622 ms | 6.22% | 2.618 ms | 6.23% | 10000 | -| 100000 | 50 | 1376x | 7.113 ms | 1.14% | 7.109 ms | 1.14% | 100000 | -| 500000 | 50 | 19x | 26.659 ms | 0.06% | 26.655 ms | 0.06% | 500000 | -| 10000 | 200 | 1088x | 10.444 ms | 2.81% | 10.440 ms | 2.81% | 10000 | -| 100000 | 200 | 513x | 29.202 ms | 1.48% | 29.198 ms | 1.48% | 100000 | -| 500000 | 200 | 64x | 116.000 ms | 0.50% | 115.995 ms | 0.50% | 500000 | - -## Protobuf Nested Message - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | inner_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|--------------|---------|------------|-------|------------|-------|--------| -| 10000 | 5 | 1568x | 686.217 us | 0.53% | 682.638 us | 0.54% | 10000 | -| 100000 | 5 | 2912x | 1.542 ms | 1.12% | 1.538 ms | 1.13% | 100000 | -| 500000 | 5 | 2800x | 4.655 ms | 0.55% | 4.652 ms | 0.55% | 500000 | -| 10000 | 20 | 204x | 2.462 ms | 0.23% | 2.458 ms | 0.23% | 10000 | -| 100000 | 20 | 104x | 4.821 ms | 0.28% | 4.817 ms | 0.28% | 100000 | -| 500000 | 20 | 1070x | 13.957 ms | 3.64% | 13.953 ms | 3.64% | 500000 | -| 10000 | 100 | 528x | 11.267 ms | 5.07% | 11.264 ms | 5.07% | 10000 | -| 100000 | 100 | 560x | 24.760 ms | 2.25% | 24.756 ms | 2.25% | 100000 | -| 500000 | 100 | 214x | 70.031 ms | 
0.91% | 70.025 ms | 0.91% | 500000 | - -## Protobuf Repeated Fields - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | avg_items | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|-----------|---------|------------|-------|------------|-------|--------| -| 10000 | 1 | 587x | 856.320 us | 0.36% | 852.793 us | 0.37% | 10000 | -| 100000 | 1 | 220x | 2.285 ms | 0.22% | 2.282 ms | 0.22% | 100000 | -| 10000 | 5 | 1232x | 1.181 ms | 0.63% | 1.177 ms | 0.63% | 10000 | -| 100000 | 5 | 2272x | 4.106 ms | 0.65% | 4.103 ms | 0.65% | 100000 | -| 10000 | 20 | 235x | 2.133 ms | 0.22% | 2.129 ms | 0.22% | 10000 | -| 100000 | 20 | 65x | 7.742 ms | 0.30% | 7.738 ms | 0.30% | 100000 | - -## Protobuf Many Repeated Fields - -### [0] NVIDIA RTX 5880 Ada Generation - -| num_rows | num_rep_fields | Samples | CPU Time | Noise | GPU Time | Noise | Rows | -|----------|----------------|---------|-----------|-------|-----------|-------|--------| -| 10000 | 10 | 306x | 1.642 ms | 0.31% | 1.638 ms | 0.31% | 10000 | -| 100000 | 10 | 832x | 5.583 ms | 0.75% | 5.579 ms | 0.75% | 100000 | -| 10000 | 30 | 109x | 4.624 ms | 0.21% | 4.620 ms | 0.21% | 10000 | -| 100000 | 30 | 752x | 16.549 ms | 1.02% | 16.545 ms | 1.02% | 100000 | -| 10000 | 50 | 560x | 7.961 ms | 2.99% | 7.958 ms | 2.99% | 10000 | -| 100000 | 50 | 541x | 27.668 ms | 0.69% | 27.664 ms | 0.69% | 100000 | From 38348ccd3adc33b2f5dc1dd1755b6ffe85a3da76 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 4 Mar 2026 17:56:42 +0800 Subject: [PATCH 045/107] Batched scalar extraction (2D grid kernel) Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 191 ++++++++++++++++++++++----- src/main/cpp/src/protobuf_common.cuh | 122 +++++++++++++++++ 2 files changed, 283 insertions(+), 30 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index f87b293909..a832034431 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -289,43 +289,174 @@ std::unique_ptr 
decode_protobuf_to_struct(cudf::column_view const& } } + // Batched scalar extraction: group non-special fixed-width fields by extraction + // category and extract all fields of each category with a single 2D kernel launch. + { + struct scalar_buf_pair { + rmm::device_uvector out_bytes; + rmm::device_uvector valid; + scalar_buf_pair(rmm::cuda_stream_view s, rmm::device_async_resource_ref m) + : out_bytes(0, s, m), valid(0, s, m) {} + }; + + // Classify each scalar field + // 0=I32, 1=U32, 2=I64, 3=U64, 4=BOOL, 5=I32zz, 6=I64zz, 7=F32, 8=F64, + // 9=I32fixed, 10=I64fixed, 11=fallback + constexpr int NUM_GROUPS = 12; + constexpr int GRP_FALLBACK = 11; + std::vector group_lists[NUM_GROUPS]; + + for (int i = 0; i < num_scalar; i++) { + int si = scalar_field_indices[i]; + auto tid = schema_output_types[si].id(); + int enc = schema[si].encoding; + bool zz = (enc == spark_rapids_jni::ENC_ZIGZAG); + + // STRING, LIST, and enum-as-string go to per-field path + if (tid == cudf::type_id::STRING || tid == cudf::type_id::LIST) continue; + + bool is_fixed = (enc == spark_rapids_jni::ENC_FIXED); + + // INT32 with enum validation goes to fallback + if (tid == cudf::type_id::INT32 && !zz && !is_fixed && + si < static_cast(enum_valid_values.size()) && + !enum_valid_values[si].empty()) { + group_lists[GRP_FALLBACK].push_back(i); + continue; + } + + int g = GRP_FALLBACK; + // INT32/INT64 with ENC_FIXED use fixed-width extraction (sfixed32/sfixed64) + if (tid == cudf::type_id::INT32 && is_fixed) g = 9; + else if (tid == cudf::type_id::INT64 && is_fixed) g = 10; + else if (tid == cudf::type_id::INT32 && !zz) g = 0; + else if (tid == cudf::type_id::UINT32) g = 1; + else if (tid == cudf::type_id::INT64 && !zz) g = 2; + else if (tid == cudf::type_id::UINT64) g = 3; + else if (tid == cudf::type_id::BOOL8) g = 4; + else if (tid == cudf::type_id::INT32 && zz) g = 5; + else if (tid == cudf::type_id::INT64 && zz) g = 6; + else if (tid == cudf::type_id::FLOAT32) g = 7; + else if (tid == 
cudf::type_id::FLOAT64) g = 8; + group_lists[g].push_back(i); + } + + // Helper: batch-extract one group using a 2D kernel, then build columns. + auto do_batch = [&](std::vector const& idxs, auto kernel_launcher) { + int nf = static_cast(idxs.size()); + if (nf == 0) return; + + std::vector> bufs; + bufs.reserve(nf); + std::vector h_descs(nf); + + for (int j = 0; j < nf; j++) { + int li = idxs[j]; + int si = scalar_field_indices[li]; + bool hd = schema[si].has_default_value; + auto& bp = *bufs.emplace_back(std::make_unique(stream, mr)); + bp.valid = rmm::device_uvector(std::max(1, num_rows), stream, mr); + // BOOL8 default comes from default_bools (converted to 0/1 int) + bool is_bool = (schema_output_types[si].id() == cudf::type_id::BOOL8); + int64_t def_i = hd ? (is_bool ? (default_bools[si] ? 1 : 0) : default_ints[si]) : 0; + h_descs[j] = {li, nullptr, bp.valid.data(), hd, + def_i, hd ? default_floats[si] : 0.0}; + } + + // kernel_launcher allocates out_bytes, sets h_descs[j].output, and launches kernel + kernel_launcher(nf, h_descs, bufs); + + // Build columns + for (int j = 0; j < nf; j++) { + int si = scalar_field_indices[idxs[j]]; + auto dt = schema_output_types[si]; + auto& bp = *bufs[j]; + auto [mask, null_count] = + protobuf_detail::make_null_mask_from_valid(bp.valid, stream, mr); + column_map[si] = std::make_unique( + dt, num_rows, bp.out_bytes.release(), std::move(mask), null_count); + } + }; + + // Varint launcher for type T with zigzag ZZ + auto varint_launch = [&](int nf, + std::vector& h_descs, + std::vector>& bufs, + size_t elem_size, auto kernel_fn) { + for (int j = 0; j < nf; j++) { + bufs[j]->out_bytes = rmm::device_uvector(num_rows * elem_size, stream, mr); + h_descs[j].output = bufs[j]->out_bytes.data(); + } + rmm::device_uvector d_descs(nf, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_descs.data(), h_descs.data(), + nf * sizeof(h_descs[0]), + cudaMemcpyHostToDevice, stream.value())); + dim3 grid((num_rows + threads - 1) / threads, nf); + 
kernel_fn(grid, threads, stream.value(), message_data, list_offsets, base_offset, + d_locations.data(), num_scalar, d_descs.data(), nf, num_rows, d_error.data()); + }; + + // Dispatch groups 0-8 as batched + #define LAUNCH_VARINT_BATCH(GROUP, TYPE, ZZ) \ + do_batch(group_lists[GROUP], [&](int nf, auto& hd, auto& bf) { \ + varint_launch(nf, hd, bf, sizeof(TYPE), \ + [](dim3 g, int t, cudaStream_t s, auto... args) { \ + protobuf_detail::extract_varint_batched_kernel<<>>( \ + args...); \ + }); \ + }) + + #define LAUNCH_FIXED_BATCH(GROUP, TYPE, WT_VAL) \ + do_batch(group_lists[GROUP], [&](int nf, auto& hd, auto& bf) { \ + varint_launch(nf, hd, bf, sizeof(TYPE), \ + [](dim3 g, int t, cudaStream_t s, auto... args) { \ + protobuf_detail::extract_fixed_batched_kernel<<>>( \ + args...); \ + }); \ + }) + + LAUNCH_VARINT_BATCH(0, int32_t, false); + LAUNCH_VARINT_BATCH(1, uint32_t, false); + LAUNCH_VARINT_BATCH(2, int64_t, false); + LAUNCH_VARINT_BATCH(3, uint64_t, false); + LAUNCH_VARINT_BATCH(4, uint8_t, false); + LAUNCH_VARINT_BATCH(5, int32_t, true); + LAUNCH_VARINT_BATCH(6, int64_t, true); + LAUNCH_FIXED_BATCH(7, float, WT_32BIT); + LAUNCH_FIXED_BATCH(8, double, WT_64BIT); + LAUNCH_FIXED_BATCH(9, int32_t, WT_32BIT); + LAUNCH_FIXED_BATCH(10, int64_t, WT_64BIT); + + #undef LAUNCH_VARINT_BATCH + #undef LAUNCH_FIXED_BATCH + + // Per-field fallback (INT32 with enum, etc.) + for (int i : group_lists[GRP_FALLBACK]) { + int schema_idx = scalar_field_indices[i]; + auto const dt = schema_output_types[schema_idx]; + auto const enc = schema[schema_idx].encoding; + bool has_def = schema[schema_idx].has_default_value; + TopLevelLocationProvider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + column_map[schema_idx] = extract_typed_column(dt, enc, message_data, loc_provider, + num_rows, blocks, threads, has_def, + has_def ? default_ints[schema_idx] : 0, + has_def ? default_floats[schema_idx] : 0.0, + has_def ? 
default_bools[schema_idx] : false, + default_strings[schema_idx], schema_idx, enum_valid_values, enum_names, + d_row_has_invalid_enum, d_error, stream, mr); + } + } + + // Per-field extraction for STRING and LIST types for (int i = 0; i < num_scalar; i++) { int schema_idx = scalar_field_indices[i]; auto const dt = schema_output_types[schema_idx]; + if (dt.id() != cudf::type_id::STRING && dt.id() != cudf::type_id::LIST) continue; auto const enc = schema[schema_idx].encoding; bool has_def = schema[schema_idx].has_default_value; switch (dt.id()) { - case cudf::type_id::BOOL8: - case cudf::type_id::INT32: - case cudf::type_id::UINT32: - case cudf::type_id::INT64: - case cudf::type_id::UINT64: - case cudf::type_id::FLOAT32: - case cudf::type_id::FLOAT64: { - TopLevelLocationProvider loc_provider{ - list_offsets, base_offset, d_locations.data(), i, num_scalar}; - column_map[schema_idx] = extract_typed_column(dt, - enc, - message_data, - loc_provider, - num_rows, - blocks, - threads, - has_def, - has_def ? default_ints[schema_idx] : 0, - has_def ? default_floats[schema_idx] : 0.0, - has_def ? 
default_bools[schema_idx] : false, - default_strings[schema_idx], - schema_idx, - enum_valid_values, - enum_names, - d_row_has_invalid_enum, - d_error, - stream, - mr); - break; - } case cudf::type_id::STRING: { if (enc == spark_rapids_jni::ENC_ENUM_STRING) { // ENUM-as-string path: diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 88ec29f0c9..b68e5df01e 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -609,6 +609,128 @@ __global__ void extract_fixed_kernel(uint8_t const* message_data, if (valid) valid[idx] = true; } +// ============================================================================ +// Batched scalar extraction — one 2D kernel for N fields of the same type +// ============================================================================ + +struct batched_scalar_desc { + int loc_field_idx; // index into the locations array (column within d_locations) + void* output; // pre-allocated output buffer (T*) + bool* valid; // pre-allocated validity buffer + bool has_default; + int64_t default_int; + double default_float; +}; + +template +__global__ void extract_varint_batched_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* locations, + int num_loc_fields, + batched_scalar_desc const* descs, + int num_descs, + int num_rows, + int* error_flag) +{ + int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + int fi = static_cast(blockIdx.y); + if (row >= num_rows || fi >= num_descs) return; + + auto const& desc = descs[fi]; + auto loc = locations[row * num_loc_fields + desc.loc_field_idx]; + auto* out = static_cast(desc.output); + + auto const write_value = [](OutputType* dst, uint64_t val) { + if constexpr (std::is_same_v) { + *dst = static_cast(val != 0 ? 
1 : 0); + } else { + *dst = static_cast(val); + } + }; + + if (loc.offset < 0) { + if (desc.has_default) { + write_value(&out[row], static_cast(desc.default_int)); + desc.valid[row] = true; + } else { + desc.valid[row] = false; + } + return; + } + + int32_t data_offset = row_offsets[row] - base_offset + loc.offset; + uint8_t const* cur = message_data + data_offset; + uint8_t const* end = cur + loc.length; + + uint64_t v; + int n; + if (!read_varint(cur, end, v, n)) { + set_error_once(error_flag, ERR_VARINT); + desc.valid[row] = false; + return; + } + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } + write_value(&out[row], v); + desc.valid[row] = true; +} + +template +__global__ void extract_fixed_batched_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* locations, + int num_loc_fields, + batched_scalar_desc const* descs, + int num_descs, + int num_rows, + int* error_flag) +{ + int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + int fi = static_cast(blockIdx.y); + if (row >= num_rows || fi >= num_descs) return; + + auto const& desc = descs[fi]; + auto loc = locations[row * num_loc_fields + desc.loc_field_idx]; + auto* out = static_cast(desc.output); + + if (loc.offset < 0) { + if (desc.has_default) { + out[row] = static_cast(desc.default_float); + desc.valid[row] = true; + } else { + desc.valid[row] = false; + } + return; + } + + int32_t data_offset = row_offsets[row] - base_offset + loc.offset; + uint8_t const* cur = message_data + data_offset; + OutputType value; + + if constexpr (WT == WT_32BIT) { + if (loc.length < 4) { + set_error_once(error_flag, ERR_FIXED_LEN); + desc.valid[row] = false; + return; + } + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { + if (loc.length < 8) { + set_error_once(error_flag, ERR_FIXED_LEN); + desc.valid[row] = false; + return; + } + uint64_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } + 
out[row] = value; + desc.valid[row] = true; +} + +// ============================================================================ + template __global__ void extract_lengths_kernel(LocationProvider loc_provider, int total_items, From 364729960a2c11f649d70ec5698f74ea9cb39ea3 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Wed, 4 Mar 2026 17:57:37 +0800 Subject: [PATCH 046/107] style Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 168 +++++++++++++++++++++-------------- 1 file changed, 100 insertions(+), 68 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index a832034431..afce5a9379 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -296,13 +296,15 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& rmm::device_uvector out_bytes; rmm::device_uvector valid; scalar_buf_pair(rmm::cuda_stream_view s, rmm::device_async_resource_ref m) - : out_bytes(0, s, m), valid(0, s, m) {} + : out_bytes(0, s, m), valid(0, s, m) + { + } }; // Classify each scalar field // 0=I32, 1=U32, 2=I64, 3=U64, 4=BOOL, 5=I32zz, 6=I64zz, 7=F32, 8=F64, // 9=I32fixed, 10=I64fixed, 11=fallback - constexpr int NUM_GROUPS = 12; + constexpr int NUM_GROUPS = 12; constexpr int GRP_FALLBACK = 11; std::vector group_lists[NUM_GROUPS]; @@ -319,25 +321,35 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // INT32 with enum validation goes to fallback if (tid == cudf::type_id::INT32 && !zz && !is_fixed && - si < static_cast(enum_valid_values.size()) && - !enum_valid_values[si].empty()) { + si < static_cast(enum_valid_values.size()) && !enum_valid_values[si].empty()) { group_lists[GRP_FALLBACK].push_back(i); continue; } int g = GRP_FALLBACK; // INT32/INT64 with ENC_FIXED use fixed-width extraction (sfixed32/sfixed64) - if (tid == cudf::type_id::INT32 && is_fixed) g = 9; - else if (tid == cudf::type_id::INT64 && is_fixed) g = 10; - else if (tid == cudf::type_id::INT32 && !zz) g = 0; - else if (tid == 
cudf::type_id::UINT32) g = 1; - else if (tid == cudf::type_id::INT64 && !zz) g = 2; - else if (tid == cudf::type_id::UINT64) g = 3; - else if (tid == cudf::type_id::BOOL8) g = 4; - else if (tid == cudf::type_id::INT32 && zz) g = 5; - else if (tid == cudf::type_id::INT64 && zz) g = 6; - else if (tid == cudf::type_id::FLOAT32) g = 7; - else if (tid == cudf::type_id::FLOAT64) g = 8; + if (tid == cudf::type_id::INT32 && is_fixed) + g = 9; + else if (tid == cudf::type_id::INT64 && is_fixed) + g = 10; + else if (tid == cudf::type_id::INT32 && !zz) + g = 0; + else if (tid == cudf::type_id::UINT32) + g = 1; + else if (tid == cudf::type_id::INT64 && !zz) + g = 2; + else if (tid == cudf::type_id::UINT64) + g = 3; + else if (tid == cudf::type_id::BOOL8) + g = 4; + else if (tid == cudf::type_id::INT32 && zz) + g = 5; + else if (tid == cudf::type_id::INT64 && zz) + g = 6; + else if (tid == cudf::type_id::FLOAT32) + g = 7; + else if (tid == cudf::type_id::FLOAT64) + g = 8; group_lists[g].push_back(i); } @@ -351,16 +363,15 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& std::vector h_descs(nf); for (int j = 0; j < nf; j++) { - int li = idxs[j]; - int si = scalar_field_indices[li]; - bool hd = schema[si].has_default_value; + int li = idxs[j]; + int si = scalar_field_indices[li]; + bool hd = schema[si].has_default_value; auto& bp = *bufs.emplace_back(std::make_unique(stream, mr)); bp.valid = rmm::device_uvector(std::max(1, num_rows), stream, mr); // BOOL8 default comes from default_bools (converted to 0/1 int) - bool is_bool = (schema_output_types[si].id() == cudf::type_id::BOOL8); + bool is_bool = (schema_output_types[si].id() == cudf::type_id::BOOL8); int64_t def_i = hd ? (is_bool ? (default_bools[si] ? 1 : 0) : default_ints[si]) : 0; - h_descs[j] = {li, nullptr, bp.valid.data(), hd, - def_i, hd ? default_floats[si] : 0.0}; + h_descs[j] = {li, nullptr, bp.valid.data(), hd, def_i, hd ? 
default_floats[si] : 0.0}; } // kernel_launcher allocates out_bytes, sets h_descs[j].output, and launches kernel @@ -368,8 +379,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Build columns for (int j = 0; j < nf; j++) { - int si = scalar_field_indices[idxs[j]]; - auto dt = schema_output_types[si]; + int si = scalar_field_indices[idxs[j]]; + auto dt = schema_output_types[si]; auto& bp = *bufs[j]; auto [mask, null_count] = protobuf_detail::make_null_mask_from_valid(bp.valid, stream, mr); @@ -380,55 +391,64 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Varint launcher for type T with zigzag ZZ auto varint_launch = [&](int nf, - std::vector& h_descs, - std::vector>& bufs, - size_t elem_size, auto kernel_fn) { + std::vector& h_descs, + std::vector>& bufs, + size_t elem_size, + auto kernel_fn) { for (int j = 0; j < nf; j++) { bufs[j]->out_bytes = rmm::device_uvector(num_rows * elem_size, stream, mr); - h_descs[j].output = bufs[j]->out_bytes.data(); + h_descs[j].output = bufs[j]->out_bytes.data(); } rmm::device_uvector d_descs(nf, stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_descs.data(), h_descs.data(), - nf * sizeof(h_descs[0]), - cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_descs.data(), + h_descs.data(), + nf * sizeof(h_descs[0]), + cudaMemcpyHostToDevice, + stream.value())); dim3 grid((num_rows + threads - 1) / threads, nf); - kernel_fn(grid, threads, stream.value(), message_data, list_offsets, base_offset, - d_locations.data(), num_scalar, d_descs.data(), nf, num_rows, d_error.data()); + kernel_fn(grid, + threads, + stream.value(), + message_data, + list_offsets, + base_offset, + d_locations.data(), + num_scalar, + d_descs.data(), + nf, + num_rows, + d_error.data()); }; - // Dispatch groups 0-8 as batched - #define LAUNCH_VARINT_BATCH(GROUP, TYPE, ZZ) \ - do_batch(group_lists[GROUP], [&](int nf, auto& hd, auto& bf) { \ - varint_launch(nf, hd, bf, sizeof(TYPE), \ - [](dim3 g, int 
t, cudaStream_t s, auto... args) { \ - protobuf_detail::extract_varint_batched_kernel<<>>( \ - args...); \ - }); \ - }) - - #define LAUNCH_FIXED_BATCH(GROUP, TYPE, WT_VAL) \ - do_batch(group_lists[GROUP], [&](int nf, auto& hd, auto& bf) { \ - varint_launch(nf, hd, bf, sizeof(TYPE), \ - [](dim3 g, int t, cudaStream_t s, auto... args) { \ - protobuf_detail::extract_fixed_batched_kernel<<>>( \ - args...); \ - }); \ - }) - - LAUNCH_VARINT_BATCH(0, int32_t, false); +// Dispatch groups 0-8 as batched +#define LAUNCH_VARINT_BATCH(GROUP, TYPE, ZZ) \ + do_batch(group_lists[GROUP], [&](int nf, auto& hd, auto& bf) { \ + varint_launch(nf, hd, bf, sizeof(TYPE), [](dim3 g, int t, cudaStream_t s, auto... args) { \ + protobuf_detail::extract_varint_batched_kernel<<>>(args...); \ + }); \ + }) + +#define LAUNCH_FIXED_BATCH(GROUP, TYPE, WT_VAL) \ + do_batch(group_lists[GROUP], [&](int nf, auto& hd, auto& bf) { \ + varint_launch(nf, hd, bf, sizeof(TYPE), [](dim3 g, int t, cudaStream_t s, auto... args) { \ + protobuf_detail::extract_fixed_batched_kernel<<>>(args...); \ + }); \ + }) + + LAUNCH_VARINT_BATCH(0, int32_t, false); LAUNCH_VARINT_BATCH(1, uint32_t, false); - LAUNCH_VARINT_BATCH(2, int64_t, false); + LAUNCH_VARINT_BATCH(2, int64_t, false); LAUNCH_VARINT_BATCH(3, uint64_t, false); - LAUNCH_VARINT_BATCH(4, uint8_t, false); - LAUNCH_VARINT_BATCH(5, int32_t, true); - LAUNCH_VARINT_BATCH(6, int64_t, true); - LAUNCH_FIXED_BATCH(7, float, WT_32BIT); - LAUNCH_FIXED_BATCH(8, double, WT_64BIT); - LAUNCH_FIXED_BATCH(9, int32_t, WT_32BIT); + LAUNCH_VARINT_BATCH(4, uint8_t, false); + LAUNCH_VARINT_BATCH(5, int32_t, true); + LAUNCH_VARINT_BATCH(6, int64_t, true); + LAUNCH_FIXED_BATCH(7, float, WT_32BIT); + LAUNCH_FIXED_BATCH(8, double, WT_64BIT); + LAUNCH_FIXED_BATCH(9, int32_t, WT_32BIT); LAUNCH_FIXED_BATCH(10, int64_t, WT_64BIT); - #undef LAUNCH_VARINT_BATCH - #undef LAUNCH_FIXED_BATCH +#undef LAUNCH_VARINT_BATCH +#undef LAUNCH_FIXED_BATCH // Per-field fallback (INT32 with enum, etc.) 
for (int i : group_lists[GRP_FALLBACK]) { @@ -438,13 +458,25 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& bool has_def = schema[schema_idx].has_default_value; TopLevelLocationProvider loc_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; - column_map[schema_idx] = extract_typed_column(dt, enc, message_data, loc_provider, - num_rows, blocks, threads, has_def, - has_def ? default_ints[schema_idx] : 0, - has_def ? default_floats[schema_idx] : 0.0, - has_def ? default_bools[schema_idx] : false, - default_strings[schema_idx], schema_idx, enum_valid_values, enum_names, - d_row_has_invalid_enum, d_error, stream, mr); + column_map[schema_idx] = extract_typed_column(dt, + enc, + message_data, + loc_provider, + num_rows, + blocks, + threads, + has_def, + has_def ? default_ints[schema_idx] : 0, + has_def ? default_floats[schema_idx] : 0.0, + has_def ? default_bools[schema_idx] : false, + default_strings[schema_idx], + schema_idx, + enum_valid_values, + enum_names, + d_row_has_invalid_enum, + d_error, + stream, + mr); } } From 6af15772e45be9a9006da1171ddf38571c3fb1a1 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 10:55:40 +0800 Subject: [PATCH 047/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_common.cuh | 1 - src/main/cpp/src/protobuf_kernels.cu | 29 +++++---- .../rapids/jni/ProtobufSchemaDescriptor.java | 65 ++++++++++++------- 3 files changed, 57 insertions(+), 38 deletions(-) diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index b68e5df01e..b6689fa2a6 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -1011,7 +1011,6 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in __global__ void scan_repeated_field_occurrences_kernel(cudf::column_device_view const d_in, device_nested_field_descriptor const* schema, int schema_idx, - int depth_level, int32_t const* 
output_offsets, repeated_occurrence* occurrences, int* error_flag); diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 0d70893637..844e4c749d 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -331,12 +331,12 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in int fn = tag.field_number; int wt = tag.wire_type; - // Lookup repeated field by field_number if (fn_to_rep_idx != nullptr && fn > 0 && fn < fn_to_rep_size) { int i = fn_to_rep_idx[fn]; if (i >= 0) { int schema_idx = repeated_field_indices[i]; - if (!count_repeated_element(cur, + if (schema[schema_idx].depth == depth_level && + !count_repeated_element(cur, msg_end, wt, schema[schema_idx].wire_type, @@ -411,8 +411,7 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in __global__ void scan_repeated_field_occurrences_kernel( cudf::column_device_view const d_in, device_nested_field_descriptor const* schema, - int schema_idx, // Which field in schema we're scanning - int depth_level, + int schema_idx, int32_t const* output_offsets, // Pre-computed offsets from prefix sum [num_rows + 1] repeated_occurrence* occurrences, // Output: all occurrences [total_count] int* error_flag) @@ -500,13 +499,13 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co uint8_t const* cur = bytes + start; uint8_t const* msg_end = bytes + end; - // Per-field write indices, initialized from the pre-computed offsets. - // Use a fixed-size stack array to avoid dynamic allocation. - // MAX_REPEATED_SCAN_FIELDS should be generous enough for practical schemas. constexpr int MAX_STACK_FIELDS = 128; + if (num_scan_fields > MAX_STACK_FIELDS) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } int write_idx[MAX_STACK_FIELDS]; - int actual_fields = num_scan_fields < MAX_STACK_FIELDS ? 
num_scan_fields : MAX_STACK_FIELDS; - for (int f = 0; f < actual_fields; f++) { + for (int f = 0; f < num_scan_fields; f++) { write_idx[f] = scan_descs[f].row_offsets[row]; } @@ -535,11 +534,11 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co if (fn_to_desc_idx != nullptr && fn > 0 && fn < fn_to_desc_size) { int f = fn_to_desc_idx[fn]; - if (f >= 0) { + if (f >= 0 && f < num_scan_fields) { if (!try_scan(f)) return; } } else { - for (int f = 0; f < actual_fields; f++) { + for (int f = 0; f < num_scan_fields; f++) { if (scan_descs[f].field_number == fn) { if (!try_scan(f)) return; } @@ -768,6 +767,8 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, int fn = tag.field_number; int wt = tag.wire_type; + // After decode_tag, `cur` points past the tag bytes (at the field data). + // Both count_repeated_element and skip_field expect this post-tag position. for (int ri = 0; ri < num_repeated; ri++) { int schema_idx = repeated_indices[ri]; if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { @@ -1122,9 +1123,9 @@ __global__ void copy_enum_string_chars_kernel( int32_t src_begin = enum_name_offsets[mid]; int32_t src_end = enum_name_offsets[mid + 1]; int32_t dst_begin = output_offsets[row]; - for (int32_t i = 0; i < (src_end - src_begin); ++i) { - out_chars[dst_begin + i] = static_cast(enum_name_chars[src_begin + i]); - } + memcpy(out_chars + dst_begin, + enum_name_chars + src_begin, + static_cast(src_end - src_begin)); return; } else if (mid_val < val) { left = mid + 1; diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index 69f7e76b7d..eace6594ed 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -21,12 +21,11 @@ * that describe field structure, types, defaults, and 
enum metadata. * *

    Use this class instead of passing 15+ individual arrays through the JNI boundary. - * Validation is performed once in the constructor. + * Validation is performed once in the constructor (and again on deserialization). * - *

    The arrays are intentionally exposed as package-private (not public) to allow - * zero-copy access from {@link Protobuf} within the same package, while preventing - * external code from mutating the contents after construction. Callers outside this - * package should treat instances as opaque and immutable. + *

    All arrays are defensively copied in the constructor to guarantee immutability. + * Package-private field access from {@link Protobuf} is safe because the stored arrays + * cannot be mutated by the original caller. */ public final class ProtobufSchemaDescriptor implements java.io.Serializable { private static final long serialVersionUID = 1L; @@ -69,6 +68,44 @@ public ProtobufSchemaDescriptor( int[][] enumValidValues, byte[][][] enumNames) { + validate(fieldNumbers, parentIndices, depthLevels, wireTypes, outputTypeIds, + encodings, isRepeated, isRequired, hasDefaultValue, defaultInts, + defaultFloats, defaultBools, defaultStrings, enumValidValues, enumNames); + + this.fieldNumbers = fieldNumbers.clone(); + this.parentIndices = parentIndices.clone(); + this.depthLevels = depthLevels.clone(); + this.wireTypes = wireTypes.clone(); + this.outputTypeIds = outputTypeIds.clone(); + this.encodings = encodings.clone(); + this.isRepeated = isRepeated.clone(); + this.isRequired = isRequired.clone(); + this.hasDefaultValue = hasDefaultValue.clone(); + this.defaultInts = defaultInts.clone(); + this.defaultFloats = defaultFloats.clone(); + this.defaultBools = defaultBools.clone(); + this.defaultStrings = defaultStrings.clone(); + this.enumValidValues = enumValidValues.clone(); + this.enumNames = enumNames.clone(); + } + + public int numFields() { return fieldNumbers.length; } + + private void readObject(java.io.ObjectInputStream in) + throws java.io.IOException, ClassNotFoundException { + in.defaultReadObject(); + validate(fieldNumbers, parentIndices, depthLevels, wireTypes, outputTypeIds, + encodings, isRepeated, isRequired, hasDefaultValue, defaultInts, + defaultFloats, defaultBools, defaultStrings, enumValidValues, enumNames); + } + + private static void validate( + int[] fieldNumbers, int[] parentIndices, int[] depthLevels, + int[] wireTypes, int[] outputTypeIds, int[] encodings, + boolean[] isRepeated, boolean[] isRequired, boolean[] hasDefaultValue, + long[] 
defaultInts, double[] defaultFloats, boolean[] defaultBools, + byte[][] defaultStrings, int[][] enumValidValues, byte[][][] enumNames) { + if (fieldNumbers == null || parentIndices == null || depthLevels == null || wireTypes == null || outputTypeIds == null || encodings == null || isRepeated == null || isRequired == null || hasDefaultValue == null || @@ -110,23 +147,5 @@ public ProtobufSchemaDescriptor( } } } - - this.fieldNumbers = fieldNumbers; - this.parentIndices = parentIndices; - this.depthLevels = depthLevels; - this.wireTypes = wireTypes; - this.outputTypeIds = outputTypeIds; - this.encodings = encodings; - this.isRepeated = isRepeated; - this.isRequired = isRequired; - this.hasDefaultValue = hasDefaultValue; - this.defaultInts = defaultInts; - this.defaultFloats = defaultFloats; - this.defaultBools = defaultBools; - this.defaultStrings = defaultStrings; - this.enumValidValues = enumValidValues; - this.enumNames = enumNames; } - - public int numFields() { return fieldNumbers.length; } } From 0d9d105179677872406a1b00067c4c2e772b4139 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 12:41:25 +0800 Subject: [PATCH 048/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_common.cuh | 38 +++++++++++++++---- src/main/cpp/src/protobuf_kernels.cu | 5 ++- .../rapids/jni/ProtobufSchemaDescriptor.java | 34 +++++++++++++++-- 3 files changed, 66 insertions(+), 11 deletions(-) diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index b6689fa2a6..118031328b 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -219,19 +219,43 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con } case WT_SGROUP: { auto const* start = cur; - // Recursively skip until the matching end-group tag. 
- while (cur < end) { + int depth = 1; + while (cur < end && depth > 0) { uint64_t key; int key_bytes; if (!read_varint(cur, end, key, key_bytes)) return -1; cur += key_bytes; int inner_wt = static_cast(key & 0x7); - if (inner_wt == WT_EGROUP) { return static_cast(cur - start); } - - int inner_size = get_wire_type_size(inner_wt, cur, end); - if (inner_size < 0 || cur + inner_size > end) return -1; - cur += inner_size; + if (inner_wt == WT_EGROUP) { + --depth; + if (depth == 0) { return static_cast(cur - start); } + } else if (inner_wt == WT_SGROUP) { + if (++depth > 32) return -1; + } else { + int inner_size = -1; + switch (inner_wt) { + case WT_VARINT: { + uint64_t dummy; + int vbytes; + if (!read_varint(cur, end, dummy, vbytes)) return -1; + inner_size = vbytes; + break; + } + case WT_64BIT: inner_size = 8; break; + case WT_LEN: { + uint64_t len; + int len_bytes; + if (!read_varint(cur, end, len, len_bytes)) return -1; + inner_size = len_bytes + static_cast(len); + break; + } + case WT_32BIT: inner_size = 4; break; + default: return -1; + } + if (inner_size < 0 || cur + inner_size > end) return -1; + cur += inner_size; + } } return -1; } diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 844e4c749d..0dc60bbd6c 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -381,7 +381,10 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in if (fn_to_nested_idx != nullptr && fn > 0 && fn < fn_to_nested_size) { int i = fn_to_nested_idx[fn]; if (i >= 0) { - if (!handle_nested(i)) return; + int schema_idx = nested_field_indices[i]; + if (schema[schema_idx].depth == depth_level) { + if (!handle_nested(i)) return; + } } } else { for (int i = 0; i < num_nested_fields; i++) { diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index eace6594ed..7e7c27b79f 
100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -84,9 +84,9 @@ public ProtobufSchemaDescriptor( this.defaultInts = defaultInts.clone(); this.defaultFloats = defaultFloats.clone(); this.defaultBools = defaultBools.clone(); - this.defaultStrings = defaultStrings.clone(); - this.enumValidValues = enumValidValues.clone(); - this.enumNames = enumNames.clone(); + this.defaultStrings = deepCopy(defaultStrings); + this.enumValidValues = deepCopy(enumValidValues); + this.enumNames = deepCopy(enumNames); } public int numFields() { return fieldNumbers.length; } @@ -99,6 +99,34 @@ private void readObject(java.io.ObjectInputStream in) defaultFloats, defaultBools, defaultStrings, enumValidValues, enumNames); } + private static byte[][] deepCopy(byte[][] src) { + byte[][] dst = new byte[src.length][]; + for (int i = 0; i < src.length; i++) { + dst[i] = src[i] != null ? src[i].clone() : null; + } + return dst; + } + + private static int[][] deepCopy(int[][] src) { + int[][] dst = new int[src.length][]; + for (int i = 0; i < src.length; i++) { + dst[i] = src[i] != null ? src[i].clone() : null; + } + return dst; + } + + private static byte[][][] deepCopy(byte[][][] src) { + byte[][][] dst = new byte[src.length][][]; + for (int i = 0; i < src.length; i++) { + if (src[i] == null) continue; + dst[i] = new byte[src[i].length][]; + for (int j = 0; j < src[i].length; j++) { + dst[i][j] = src[i][j] != null ? 
src[i][j].clone() : null; + } + } + return dst; + } + private static void validate( int[] fieldNumbers, int[] parentIndices, int[] depthLevels, int[] wireTypes, int[] outputTypeIds, int[] encodings, From e861f40fb1602d68fa2cb9a85f9d4c3913e49b73 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 13:32:43 +0800 Subject: [PATCH 049/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 11 +++++++++-- src/main/cpp/src/protobuf.cu | 3 +++ src/main/cpp/src/protobuf_common.cuh | 19 ++++++++++--------- src/main/cpp/src/protobuf_kernels.cu | 2 +- .../com/nvidia/spark/rapids/jni/Protobuf.java | 3 +++ .../rapids/jni/ProtobufSchemaDescriptor.java | 11 +++++++++++ 6 files changed, 37 insertions(+), 12 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index 36220605f7..65dcd796a6 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -176,7 +176,10 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, } else { jsize len = env->GetArrayLength(byte_arr); jbyte* bytes = env->GetByteArrayElements(byte_arr, nullptr); - if (bytes == nullptr) { return 0; } + if (bytes == nullptr) { + env->DeleteLocalRef(byte_arr); + return 0; + } default_string_values.emplace_back(reinterpret_cast(bytes), reinterpret_cast(bytes) + len); env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT); @@ -224,7 +227,11 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, } else { jsize len = env->GetArrayLength(name_bytes); jbyte* bytes = env->GetByteArrayElements(name_bytes, nullptr); - if (bytes == nullptr) { return 0; } + if (bytes == nullptr) { + env->DeleteLocalRef(name_bytes); + env->DeleteLocalRef(names_arr); + return 0; + } names_for_field.emplace_back(reinterpret_cast(bytes), reinterpret_cast(bytes) + len); env->ReleaseByteArrayElements(name_bytes, bytes, JNI_ABORT); diff --git a/src/main/cpp/src/protobuf.cu 
b/src/main/cpp/src/protobuf.cu index afce5a9379..a30f0100e3 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -138,6 +138,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case ERR_FIXED_LEN: return "Protobuf decode error: invalid fixed-width or packed field length"; case ERR_REQUIRED: return "Protobuf decode error: missing required field"; + case ERR_SCHEMA_TOO_LARGE: + return "Protobuf decode error: schema exceeds maximum supported repeated fields per kernel " + "(128)"; default: return "Protobuf decode error: unknown error"; } }; diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 118031328b..5321208a91 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -64,15 +64,16 @@ constexpr int MAX_VARINT_BYTES = 10; constexpr int THREADS_PER_BLOCK = 256; // Error codes for kernel error reporting. -constexpr int ERR_BOUNDS = 1; -constexpr int ERR_VARINT = 2; -constexpr int ERR_FIELD_NUMBER = 3; -constexpr int ERR_WIRE_TYPE = 4; -constexpr int ERR_OVERFLOW = 5; -constexpr int ERR_FIELD_SIZE = 6; -constexpr int ERR_SKIP = 7; -constexpr int ERR_FIXED_LEN = 8; -constexpr int ERR_REQUIRED = 9; +constexpr int ERR_BOUNDS = 1; +constexpr int ERR_VARINT = 2; +constexpr int ERR_FIELD_NUMBER = 3; +constexpr int ERR_WIRE_TYPE = 4; +constexpr int ERR_OVERFLOW = 5; +constexpr int ERR_FIELD_SIZE = 6; +constexpr int ERR_SKIP = 7; +constexpr int ERR_FIXED_LEN = 8; +constexpr int ERR_REQUIRED = 9; +constexpr int ERR_SCHEMA_TOO_LARGE = 10; // Maximum supported nesting depth for recursive struct decoding. 
constexpr int MAX_NESTED_STRUCT_DECODE_DEPTH = 10; diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 0dc60bbd6c..75163d1991 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -504,7 +504,7 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co constexpr int MAX_STACK_FIELDS = 128; if (num_scan_fields > MAX_STACK_FIELDS) { - set_error_once(error_flag, ERR_OVERFLOW); + set_error_once(error_flag, ERR_SCHEMA_TOO_LARGE); return; } int write_idx[MAX_STACK_FIELDS]; diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index 43f2a3eb01..40c7675278 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -73,6 +73,9 @@ public class Protobuf { public static ColumnVector decodeToStruct(ColumnView binaryInput, ProtobufSchemaDescriptor schema, boolean failOnErrors) { + if (schema == null) { + throw new IllegalArgumentException("schema must not be null"); + } long handle = decodeToStruct(binaryInput.getNativeView(), schema.fieldNumbers, schema.parentIndices, schema.depthLevels, schema.wireTypes, schema.outputTypeIds, schema.encodings, diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index 7e7c27b79f..141eff53f4 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -159,6 +159,12 @@ private static void validate( "Invalid field number at index " + i + ": " + fieldNumbers[i] + " (must be 1-" + MAX_FIELD_NUMBER + ")"); } + int wt = wireTypes[i]; + if (wt != 0 && wt != 1 && wt != 2 && wt != 5) { + throw new IllegalArgumentException( + "Invalid wire type at index " + i + ": " + wt + + " (must be one of 
{0, 1, 2, 5})"); + } int enc = encodings[i]; if (enc < Protobuf.ENC_DEFAULT || enc > Protobuf.ENC_ENUM_STRING) { throw new IllegalArgumentException( @@ -173,6 +179,11 @@ private static void validate( "(binary search requires it), but found " + ev[j - 1] + " before " + ev[j]); } } + if (enumNames[i] != null && enumNames[i].length != ev.length) { + throw new IllegalArgumentException( + "enumNames[" + i + "].length (" + enumNames[i].length + ") must equal " + + "enumValidValues[" + i + "].length (" + ev.length + ")"); + } } } } From 2ff75a94c6273f6f90a9cc4ce91999d89cbc68f2 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 14:12:58 +0800 Subject: [PATCH 050/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/benchmarks/protobuf_decode.cu | 2 +- src/main/cpp/src/ProtobufJni.cpp | 5 ++++- src/main/cpp/src/protobuf.cu | 12 ++++++++++-- src/main/cpp/src/protobuf_common.cuh | 1 + 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/main/cpp/benchmarks/protobuf_decode.cu b/src/main/cpp/benchmarks/protobuf_decode.cu index 4bf4bd622b..2f48a431dd 100644 --- a/src/main/cpp/benchmarks/protobuf_decode.cu +++ b/src/main/cpp/benchmarks/protobuf_decode.cu @@ -47,7 +47,7 @@ void encode_varint(std::vector& buf, uint64_t value) void encode_tag(std::vector& buf, int field_number, int wire_type) { - encode_varint(buf, static_cast(field_number << 3 | wire_type)); + encode_varint(buf, (static_cast(field_number) << 3) | static_cast(wire_type)); } void encode_varint_field(std::vector& buf, int field_number, int64_t value) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index 65dcd796a6..5882bfbcb3 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -198,7 +198,10 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, } else { jsize len = env->GetArrayLength(int_arr); jint* ints = env->GetIntArrayElements(int_arr, nullptr); - if (ints == nullptr) { return 0; 
} + if (ints == nullptr) { + env->DeleteLocalRef(int_arr); + return 0; + } enum_values.emplace_back(ints, ints + len); env->ReleaseIntArrayElements(int_arr, ints, JNI_ABORT); env->DeleteLocalRef(int_arr); diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index a30f0100e3..fc6c6d7834 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -528,11 +528,19 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& mr); } else { // Missing enum metadata for enum-as-string field; mark as decode error. - CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value())); + { + int err_val = ERR_BOUNDS; + CUDF_CUDA_TRY(cudaMemcpyAsync( + d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); + } column_map[schema_idx] = make_null_column(dt, num_rows, stream, mr); } } else { - CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value())); + { + int err_val = ERR_BOUNDS; + CUDF_CUDA_TRY(cudaMemcpyAsync( + d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); + } column_map[schema_idx] = make_null_column(dt, num_rows, stream, mr); } } else { diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 5321208a91..ba4be39884 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -248,6 +248,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con uint64_t len; int len_bytes; if (!read_varint(cur, end, len, len_bytes)) return -1; + if (len > static_cast(INT_MAX - len_bytes)) return -1; inner_size = len_bytes + static_cast(len); break; } From f6ebffe6e555f071409ea959ac8a221976be61b6 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 16:17:17 +0800 Subject: [PATCH 051/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 1 + src/main/cpp/src/protobuf_builders.cu | 29 
+++++++++++++++++++++++++++ src/main/cpp/src/protobuf_common.cuh | 1 + src/main/cpp/src/protobuf_kernels.cu | 6 +++++- 4 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index fc6c6d7834..36811b4135 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -842,6 +842,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& num_rows, enum_valid_values[schema_idx], enum_names[schema_idx], + d_row_has_invalid_enum, d_error, stream, mr); diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 3b1c43d369..bd5ed6a2e8 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -344,6 +344,7 @@ std::unique_ptr build_repeated_enum_string_column( int num_rows, std::vector const& valid_enums, std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_has_invalid_enum, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) @@ -398,6 +399,34 @@ std::unique_ptr build_repeated_enum_string_column( rmm::device_uvector elem_valid(total_count, stream, mr); thrust::fill(rmm::exec_policy(stream), elem_valid.data(), elem_valid.end(), true); + // 3b. Validate enum values — mark invalid as false in elem_valid + rmm::device_uvector d_elem_has_invalid_enum(total_count, stream, mr); + thrust::fill(rmm::exec_policy(stream), + d_elem_has_invalid_enum.begin(), + d_elem_has_invalid_enum.end(), + false); + validate_enum_values_kernel<<>>( + enum_ints.data(), + elem_valid.data(), + d_elem_has_invalid_enum.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + total_count); + + // 3c. Propagate per-element invalid enum flags to per-row flags for struct null mask. + // Spark CPU nullifies the entire struct row when any repeated enum element is invalid. 
+ if (d_row_has_invalid_enum.size() > 0 && total_count > 0) { + auto const* occs = d_occurrences.data(); + auto const* elem_invalid = d_elem_has_invalid_enum.data(); + auto* row_invalid = d_row_has_invalid_enum.data(); + thrust::for_each(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + [occs, elem_invalid, row_invalid] __device__(int idx) { + if (elem_invalid[idx]) { row_invalid[occs[idx].row_idx] = true; } + }); + } + // 4. Compute per-element string lengths rmm::device_uvector elem_lengths(total_count, stream, mr); compute_enum_string_lengths_kernel<<>>( diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index ba4be39884..f11460556f 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -1198,6 +1198,7 @@ std::unique_ptr build_repeated_enum_string_column( int num_rows, std::vector const& valid_enums, std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_has_invalid_enum, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr); diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 75163d1991..88f223784e 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -707,7 +707,11 @@ __global__ void scan_repeated_message_children_kernel( if (wt == WT_VARINT) { uint64_t dummy; int vbytes; - if (read_varint(cur, msg_end, dummy, vbytes)) { data_length = vbytes; } + if (!read_varint(cur, msg_end, dummy, vbytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + data_length = vbytes; } else if (wt == WT_32BIT) { data_length = 4; } else if (wt == WT_64BIT) { From 7a312d47af432b24d74c8340e2857e6dc6c55ee4 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 17:07:42 +0800 Subject: [PATCH 052/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 5 
++++- src/main/cpp/src/protobuf_common.cuh | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index 5882bfbcb3..f8dce47a91 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -224,7 +224,10 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, names_for_field.reserve(num_names); for (jsize j = 0; j < num_names; ++j) { jbyteArray name_bytes = static_cast(env->GetObjectArrayElement(names_arr, j)); - if (env->ExceptionCheck()) { return 0; } + if (env->ExceptionCheck()) { + env->DeleteLocalRef(names_arr); + return 0; + } if (name_bytes == nullptr) { names_for_field.emplace_back(); } else { diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index f11460556f..d1845d423b 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -722,7 +722,11 @@ __global__ void extract_fixed_batched_kernel(uint8_t const* message_data, if (loc.offset < 0) { if (desc.has_default) { - out[row] = static_cast(desc.default_float); + if constexpr (std::is_integral_v) { + out[row] = static_cast(desc.default_int); + } else { + out[row] = static_cast(desc.default_float); + } desc.valid[row] = true; } else { desc.valid[row] = false; From 4fa362cd11fa8e7939995b8f6b1bf6022b0d2929 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 17:25:57 +0800 Subject: [PATCH 053/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 6 ++++-- src/main/cpp/src/protobuf_builders.cu | 7 ++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 36811b4135..31516e26a3 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -547,7 +547,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Regular protobuf STRING (length-delimited) bool 
has_def_str = has_def; auto const& def_str = default_strings[schema_idx]; - TopLevelLocationProvider len_provider{nullptr, 0, d_locations.data(), i, num_scalar}; + TopLevelLocationProvider len_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; TopLevelLocationProvider copy_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; auto valid_fn = [locs = d_locations.data(), i, num_scalar, has_def_str] __device__( @@ -572,7 +573,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // bytes (BinaryType) represented as LIST bool has_def_bytes = has_def; auto const& def_bytes = default_strings[schema_idx]; - TopLevelLocationProvider len_provider{nullptr, 0, d_locations.data(), i, num_scalar}; + TopLevelLocationProvider len_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; TopLevelLocationProvider copy_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; auto valid_fn = [locs = d_locations.data(), i, num_scalar, has_def_bytes] __device__( diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index bd5ed6a2e8..e5ed375ec9 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -457,9 +457,10 @@ std::unique_ptr build_repeated_enum_string_column( total_count); } - // 7. Assemble strings child column - auto child_col = cudf::make_strings_column( - total_count, std::move(str_offs_col), chars.release(), 0, rmm::device_buffer{}); + // 7. Assemble strings child column with null mask from elem_valid + auto [child_mask, child_null_count] = make_null_mask_from_valid(elem_valid, stream, mr); + auto child_col = cudf::make_strings_column( + total_count, std::move(str_offs_col), chars.release(), child_null_count, std::move(child_mask)); // 8. 
Build LIST column with list offsets from per-row counts rmm::device_uvector lo(num_rows + 1, stream, mr); From 54fda9451118f49bf09d6c12a7278dda62b5311e Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 17:45:56 +0800 Subject: [PATCH 054/107] style Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_builders.cu | 2 +- src/main/cpp/src/protobuf_kernels.cu | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index e5ed375ec9..13080b7278 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -545,7 +545,7 @@ std::unique_ptr build_repeated_string_column( rmm::device_uvector str_lengths(total_count, stream, mr); auto const threads = THREADS_PER_BLOCK; auto const blocks = (total_count + threads - 1) / threads; - RepeatedLocationProvider loc_provider{nullptr, 0, d_occurrences.data()}; + RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; extract_lengths_kernel <<>>(loc_provider, total_count, str_lengths.data()); diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 88f223784e..22b993465f 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -698,7 +698,11 @@ __global__ void scan_repeated_message_children_kernel( set_error_once(error_flag, ERR_VARINT); return; } - // Store offset (after length prefix) and length + if (len > static_cast(msg_end - cur - len_bytes) || + len > static_cast(INT_MAX)) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } child_locs[occ_idx * num_child_fields + f] = {data_offset + len_bytes, static_cast(len)}; } else { From aeaf50cade2d5a1f9edacb6aa34758d9de06e082 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 20:43:29 +0800 Subject: [PATCH 055/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 29 
+++++++++++++-------------- src/main/cpp/src/protobuf_builders.cu | 16 +++++++++++---- src/main/cpp/src/protobuf_common.cuh | 21 ++++++++++--------- src/main/cpp/src/protobuf_kernels.cu | 5 +++++ 4 files changed, 42 insertions(+), 29 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 31516e26a3..27d112865d 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -141,6 +141,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case ERR_SCHEMA_TOO_LARGE: return "Protobuf decode error: schema exceeds maximum supported repeated fields per kernel " "(128)"; + case ERR_MISSING_ENUM_META: + return "Protobuf decode error: missing or mismatched enum metadata for enum-as-string " + "field"; default: return "Protobuf decode error: unknown error"; } }; @@ -529,7 +532,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } else { // Missing enum metadata for enum-as-string field; mark as decode error. { - int err_val = ERR_BOUNDS; + int err_val = ERR_MISSING_ENUM_META; CUDF_CUDA_TRY(cudaMemcpyAsync( d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); } @@ -537,7 +540,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } } else { { - int err_val = ERR_BOUNDS; + int err_val = ERR_MISSING_ENUM_META; CUDF_CUDA_TRY(cudaMemcpyAsync( d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); } @@ -849,19 +852,15 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& stream, mr); } else { - column_map[schema_idx] = build_repeated_string_column(binary_input, - message_data, - list_offsets, - base_offset, - h_device_schema[schema_idx], - d_field_counts, - d_occurrences, - total_count, - num_rows, - false, - d_error, - stream, - mr); + // Missing/mismatched enum metadata for repeated enum-as-string field. + // Set error and produce null column, consistent with the scalar path. 
+ { + int err_val = ERR_MISSING_ENUM_META; + CUDF_CUDA_TRY(cudaMemcpyAsync( + d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); + } + column_map[schema_idx] = + make_null_column(schema_output_types[schema_idx], num_rows, stream, mr); } break; } diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 13080b7278..e4bb8f3b73 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -1175,8 +1175,12 @@ std::unique_ptr build_nested_struct_column( } else { bool has_def_str = has_def; auto const& def_str = default_strings[child_schema_idx]; - NestedLocationProvider len_provider{ - nullptr, 0, d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields}; + NestedLocationProvider len_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; NestedLocationProvider copy_provider{list_offsets, base_offset, d_parent_locs.data(), @@ -1209,8 +1213,12 @@ std::unique_ptr build_nested_struct_column( // bytes (BinaryType) represented as LIST bool has_def_bytes = has_def; auto const& def_bytes = default_strings[child_schema_idx]; - NestedLocationProvider len_provider{ - nullptr, 0, d_parent_locs.data(), d_child_locations.data(), ci, num_child_fields}; + NestedLocationProvider len_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; NestedLocationProvider copy_provider{list_offsets, base_offset, d_parent_locs.data(), diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index d1845d423b..09affcfdfe 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -64,16 +64,17 @@ constexpr int MAX_VARINT_BYTES = 10; constexpr int THREADS_PER_BLOCK = 256; // Error codes for kernel error reporting. 
-constexpr int ERR_BOUNDS = 1; -constexpr int ERR_VARINT = 2; -constexpr int ERR_FIELD_NUMBER = 3; -constexpr int ERR_WIRE_TYPE = 4; -constexpr int ERR_OVERFLOW = 5; -constexpr int ERR_FIELD_SIZE = 6; -constexpr int ERR_SKIP = 7; -constexpr int ERR_FIXED_LEN = 8; -constexpr int ERR_REQUIRED = 9; -constexpr int ERR_SCHEMA_TOO_LARGE = 10; +constexpr int ERR_BOUNDS = 1; +constexpr int ERR_VARINT = 2; +constexpr int ERR_FIELD_NUMBER = 3; +constexpr int ERR_WIRE_TYPE = 4; +constexpr int ERR_OVERFLOW = 5; +constexpr int ERR_FIELD_SIZE = 6; +constexpr int ERR_SKIP = 7; +constexpr int ERR_FIXED_LEN = 8; +constexpr int ERR_REQUIRED = 9; +constexpr int ERR_SCHEMA_TOO_LARGE = 10; +constexpr int ERR_MISSING_ENUM_META = 11; // Maximum supported nesting depth for recursive struct decoding. constexpr int MAX_NESTED_STRUCT_DECODE_DEPTH = 10; diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 22b993465f..c1b594ed32 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -373,6 +373,11 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in set_error_once(error_flag, ERR_VARINT); return false; } + if (len > static_cast(msg_end - cur - len_bytes) || + len > static_cast(INT_MAX)) { + set_error_once(error_flag, ERR_OVERFLOW); + return false; + } int32_t msg_offset = static_cast(cur - bytes - start) + len_bytes; nested_locations[row * num_nested_fields + i] = {msg_offset, static_cast(len)}; return true; From d50f089a03f0a2b80a87d0f0092b3ce827f3698a Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 21:38:50 +0800 Subject: [PATCH 056/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 63 +++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 27d112865d..0b20b6bb9d 100644 --- a/src/main/cpp/src/protobuf.cu +++ 
b/src/main/cpp/src/protobuf.cu @@ -831,36 +831,47 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& break; case cudf::type_id::STRING: { auto enc = schema[schema_idx].encoding; - if (enc == spark_rapids_jni::ENC_ENUM_STRING && - schema_idx < static_cast(enum_valid_values.size()) && - schema_idx < static_cast(enum_names.size()) && - !enum_valid_values[schema_idx].empty() && - enum_valid_values[schema_idx].size() == enum_names[schema_idx].size()) { - column_map[schema_idx] = - build_repeated_enum_string_column(binary_input, - message_data, - list_offsets, - base_offset, - d_field_counts, - d_occurrences, - total_count, - num_rows, - enum_valid_values[schema_idx], - enum_names[schema_idx], - d_row_has_invalid_enum, - d_error, - stream, - mr); - } else { - // Missing/mismatched enum metadata for repeated enum-as-string field. - // Set error and produce null column, consistent with the scalar path. - { + if (enc == spark_rapids_jni::ENC_ENUM_STRING) { + if (schema_idx < static_cast(enum_valid_values.size()) && + schema_idx < static_cast(enum_names.size()) && + !enum_valid_values[schema_idx].empty() && + enum_valid_values[schema_idx].size() == enum_names[schema_idx].size()) { + column_map[schema_idx] = + build_repeated_enum_string_column(binary_input, + message_data, + list_offsets, + base_offset, + d_field_counts, + d_occurrences, + total_count, + num_rows, + enum_valid_values[schema_idx], + enum_names[schema_idx], + d_row_has_invalid_enum, + d_error, + stream, + mr); + } else { int err_val = ERR_MISSING_ENUM_META; CUDF_CUDA_TRY(cudaMemcpyAsync( d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); + column_map[schema_idx] = + make_null_column(schema_output_types[schema_idx], num_rows, stream, mr); } - column_map[schema_idx] = - make_null_column(schema_output_types[schema_idx], num_rows, stream, mr); + } else { + column_map[schema_idx] = build_repeated_string_column(binary_input, + message_data, + list_offsets, + 
base_offset, + h_device_schema[schema_idx], + d_field_counts, + d_occurrences, + total_count, + num_rows, + false, + d_error, + stream, + mr); } break; } From 63bc13a50c8b01163c7d890902322ee67afeb1f0 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 22:04:46 +0800 Subject: [PATCH 057/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 2 -- src/main/cpp/src/protobuf_builders.cu | 20 +++++++++---------- src/main/cpp/src/protobuf_common.cuh | 2 -- src/main/cpp/src/protobuf_kernels.cu | 2 -- .../com/nvidia/spark/rapids/jni/Protobuf.java | 3 +++ .../rapids/jni/ProtobufSchemaDescriptor.java | 12 ++++++++--- 6 files changed, 22 insertions(+), 19 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 0b20b6bb9d..b9d1260657 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -698,8 +698,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& scan_all_repeated_occurrences_kernel<<>>( *d_in, - d_schema.data(), - 0, d_scan_descs.data(), static_cast(h_scan_descs.size()), d_error.data(), diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index e4bb8f3b73..c90c1cc502 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -46,7 +46,7 @@ inline std::unique_ptr build_repeated_msg_child_varlen_column( } auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; + auto const blocks = (total_count + threads - 1u) / threads; rmm::device_uvector d_lengths(total_count, stream, mr); thrust::transform( @@ -254,7 +254,7 @@ std::unique_ptr build_enum_string_column( rmm::device_async_resource_ref mr) { auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((num_rows + threads - 1) / threads); + auto const blocks = static_cast((num_rows + threads - 1u) / threads); rmm::device_uvector d_valid_enums(valid_enums.size(), 
stream, mr); CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), @@ -350,7 +350,7 @@ std::unique_ptr build_repeated_enum_string_column( rmm::device_async_resource_ref mr) { auto const rep_blocks = - static_cast((total_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK); + static_cast((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); // 1. Extract enum integer values from occurrences rmm::device_uvector enum_ints(total_count, stream, mr); @@ -544,7 +544,7 @@ std::unique_ptr build_repeated_string_column( // Extract string lengths from occurrences rmm::device_uvector str_lengths(total_count, stream, mr); auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; + auto const blocks = (total_count + threads - 1u) / threads; RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; extract_lengths_kernel <<>>(loc_provider, total_count, str_lengths.data()); @@ -753,7 +753,7 @@ std::unique_ptr build_repeated_struct_column( rmm::device_uvector d_msg_row_offsets_size(total_count, stream, mr); { auto const occ_threads = THREADS_PER_BLOCK; - auto const occ_blocks = (total_count + occ_threads - 1) / occ_threads; + auto const occ_blocks = (total_count + occ_threads - 1u) / occ_threads; compute_msg_locations_from_occurrences_kernel<<>>( d_occurrences.data(), list_offsets, @@ -774,7 +774,7 @@ std::unique_ptr build_repeated_struct_column( auto& d_error = d_error_top; auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; + auto const blocks = (total_count + threads - 1u) / threads; // Use a custom kernel to scan child fields within message occurrences // This is similar to scan_nested_message_fields_kernel but operates on occurrences @@ -1032,7 +1032,7 @@ std::unique_ptr build_nested_struct_column( } auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((num_rows + threads - 1) / threads); + auto const blocks = static_cast((num_rows + 
threads - 1u) / threads); int num_child_fields = static_cast(child_field_indices.size()); std::vector h_child_field_descs(num_child_fields); @@ -1325,7 +1325,7 @@ std::unique_ptr build_repeated_child_list_column( int depth) { auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((num_parent_rows + threads - 1) / threads); + auto const blocks = static_cast((num_parent_rows + threads - 1u) / threads); auto elem_type_id = schema[child_schema_idx].output_type; rmm::device_uvector d_rep_info(num_parent_rows, stream, mr); @@ -1417,7 +1417,7 @@ std::unique_ptr build_repeated_child_list_column( std::unique_ptr child_values; auto const rep_blocks = - static_cast((total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK); + static_cast((total_rep_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); NestedRepeatedLocationProvider nr_loc{row_offsets, base_offset, parent_locs, d_rep_occs.data()}; if (elem_type_id == cudf::type_id::BOOL8 || elem_type_id == cudf::type_id::INT32 || @@ -1470,7 +1470,7 @@ std::unique_ptr build_repeated_child_list_column( } else { rmm::device_uvector d_virtual_row_offsets(total_rep_count, stream, mr); rmm::device_uvector d_virtual_parent_locs(total_rep_count, stream, mr); - auto const rep_blk = (total_rep_count + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + auto const rep_blk = (total_rep_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK; compute_virtual_parents_for_nested_repeated_kernel<< Date: Thu, 5 Mar 2026 22:24:28 +0800 Subject: [PATCH 058/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 5 +- src/main/cpp/src/protobuf_common.cuh | 7 --- src/main/cpp/src/protobuf_kernels.cu | 70 +--------------------------- 3 files changed, 5 insertions(+), 77 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index b9d1260657..dee8656e44 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -333,11 +333,14 @@ std::unique_ptr 
decode_protobuf_to_struct(cudf::column_view const& } int g = GRP_FALLBACK; - // INT32/INT64 with ENC_FIXED use fixed-width extraction (sfixed32/sfixed64) if (tid == cudf::type_id::INT32 && is_fixed) g = 9; else if (tid == cudf::type_id::INT64 && is_fixed) g = 10; + else if (tid == cudf::type_id::UINT32 && is_fixed) + g = 9; + else if (tid == cudf::type_id::UINT64 && is_fixed) + g = 10; else if (tid == cudf::type_id::INT32 && !zz) g = 0; else if (tid == cudf::type_id::UINT32) diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 6ab5975608..284f1b8fe1 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -1039,13 +1039,6 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in int const* fn_to_nested_idx = nullptr, int fn_to_nested_size = 0); -__global__ void scan_repeated_field_occurrences_kernel(cudf::column_device_view const d_in, - device_nested_field_descriptor const* schema, - int schema_idx, - int32_t const* output_offsets, - repeated_occurrence* occurrences, - int* error_flag); - __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view const d_in, repeated_field_scan_desc const* scan_descs, int num_scan_fields, diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 0b24b36996..470561fb35 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -410,77 +410,9 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in } } -/** - * Scan and record all occurrences of repeated fields. - * Called after count_repeated_fields_kernel to fill in actual locations. - * - * @note Time complexity: O(message_length * num_repeated_fields) per row. 
- */ -__global__ void scan_repeated_field_occurrences_kernel( - cudf::column_device_view const d_in, - device_nested_field_descriptor const* schema, - int schema_idx, - int32_t const* output_offsets, // Pre-computed offsets from prefix sum [num_rows + 1] - repeated_occurrence* occurrences, // Output: all occurrences [total_count] - int* error_flag) -{ - auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - cudf::detail::lists_column_device_view in{d_in}; - if (row >= in.size()) return; - - if (in.nullable() && in.is_null(row)) { return; } - - auto const base = in.offset_at(0); - auto const child = in.get_sliced_child(); - auto const* bytes = reinterpret_cast(child.data()); - int32_t start = in.offset_at(row) - base; - int32_t end = in.offset_at(row + 1) - base; - if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } - - uint8_t const* cur = bytes + start; - uint8_t const* msg_end = bytes + end; - - int target_fn = schema[schema_idx].field_number; - int target_wt = schema[schema_idx].wire_type; - int write_idx = output_offsets[row]; - - while (cur < msg_end) { - proto_tag tag; - if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } - int fn = tag.field_number; - int wt = tag.wire_type; - - if (fn == target_fn) { - bool is_packed = (wt == WT_LEN && target_wt != WT_LEN); - if (is_packed || wt == target_wt) { - if (!scan_repeated_element(cur, - msg_end, - bytes + start, - wt, - target_wt, - static_cast(row), - occurrences, - write_idx, - error_flag)) { - return; - } - } - } - - // Skip to next field - uint8_t const* next; - if (!skip_field(cur, msg_end, wt, next)) { - set_error_once(error_flag, ERR_SKIP); - return; - } - cur = next; - } -} - /** * Combined occurrence scan: scans each message ONCE and writes occurrences for ALL - * repeated fields simultaneously. Replaces N separate scan_repeated_field_occurrences_kernel - * launches with a single kernel, eliminating N-1 redundant full-message scans. 
+ * repeated fields simultaneously, scanning each message only once. */ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view const d_in, repeated_field_scan_desc const* scan_descs, From 8ae954702bb8667b3b1b3e1c9902b1f57dc596c3 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 22:48:31 +0800 Subject: [PATCH 059/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 2 +- src/main/cpp/src/protobuf_kernels.cu | 8 ++++---- .../nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java | 7 ++++--- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index dee8656e44..c2bb1de269 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -157,7 +157,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((num_rows + threads - 1) / threads); + auto const blocks = static_cast((num_rows + threads - 1u) / threads); // Allocate for counting repeated fields rmm::device_uvector d_repeated_info( diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 470561fb35..dfe73e87fa 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -142,11 +142,11 @@ __device__ bool count_repeated_element(uint8_t const* cur, return false; } uint8_t const* packed_start = cur + len_bytes; - uint8_t const* packed_end = packed_start + packed_len; - if (packed_end > msg_end) { + if (packed_len > static_cast(msg_end - packed_start)) { set_error_once(error_flag, ERR_OVERFLOW); return false; } + uint8_t const* packed_end = packed_start + packed_len; int count = 0; if (expected_wt == WT_VARINT) { @@ -220,11 +220,11 @@ __device__ bool scan_repeated_element(uint8_t const* cur, return false; } uint8_t const* packed_start = cur + len_bytes; - uint8_t const* packed_end = packed_start + 
packed_len; - if (packed_end > msg_end) { + if (packed_len > static_cast(msg_end - packed_start)) { set_error_once(error_flag, ERR_OVERFLOW); return false; } + uint8_t const* packed_end = packed_start + packed_len; if (expected_wt == WT_VARINT) { uint8_t const* p = packed_start; diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index e37bd12564..3c35784d63 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -179,10 +179,11 @@ private static void validate( if (enumValidValues[i] != null) { int[] ev = enumValidValues[i]; for (int j = 1; j < ev.length; j++) { - if (ev[j] < ev[j - 1]) { + if (ev[j] <= ev[j - 1]) { throw new IllegalArgumentException( - "enumValidValues[" + i + "] must be sorted in ascending order " + - "(binary search requires it), but found " + ev[j - 1] + " before " + ev[j]); + "enumValidValues[" + i + "] must be strictly sorted in ascending order " + + "(binary search requires unique values), but found " + ev[j - 1] + + " followed by " + ev[j]); } } if (enumNames[i] != null && enumNames[i].length != ev.length) { From 756010df877f4f957eed3464a2f8794aeda517e6 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 23:10:30 +0800 Subject: [PATCH 060/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_builders.cu | 11 +++++------ src/main/cpp/src/protobuf_common.cuh | 7 ++++--- src/main/cpp/src/protobuf_kernels.cu | 8 ++++++++ 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index c90c1cc502..f2ff0a89fb 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -354,9 +354,10 @@ std::unique_ptr build_repeated_enum_string_column( // 1. 
Extract enum integer values from occurrences rmm::device_uvector enum_ints(total_count, stream, mr); + rmm::device_uvector elem_valid(total_count, stream, mr); RepeatedLocationProvider rep_loc{list_offsets, base_offset, d_occurrences.data()}; extract_varint_kernel<<>>( - message_data, rep_loc, total_count, enum_ints.data(), nullptr, d_error.data()); + message_data, rep_loc, total_count, enum_ints.data(), elem_valid.data(), d_error.data()); // 2. Build device-side enum lookup tables rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); @@ -395,11 +396,9 @@ std::unique_ptr build_repeated_enum_string_column( stream.value())); } - // 3. Per-element validity - rmm::device_uvector elem_valid(total_count, stream, mr); - thrust::fill(rmm::exec_policy(stream), elem_valid.data(), elem_valid.end(), true); - - // 3b. Validate enum values — mark invalid as false in elem_valid + // 3. Validate enum values — mark invalid as false in elem_valid + // (elem_valid was already populated by extract_varint_kernel: true for success, false for + // failure) rmm::device_uvector d_elem_has_invalid_enum(total_count, stream, mr); thrust::fill(rmm::exec_policy(stream), d_elem_has_invalid_enum.begin(), diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 284f1b8fe1..a7bea4bb52 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -344,12 +344,13 @@ __device__ inline bool decode_tag(uint8_t const*& cur, } cur += key_bytes; - tag.field_number = static_cast(key >> 3); - tag.wire_type = static_cast(key & 0x7); - if (tag.field_number == 0) { + uint64_t fn = key >> 3; + if (fn == 0 || fn > static_cast(INT_MAX)) { set_error_once(error_flag, ERR_FIELD_NUMBER); return false; } + tag.field_number = static_cast(fn); + tag.wire_type = static_cast(key & 0x7); return true; } diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index dfe73e87fa..d94809c317 100644 --- 
a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -652,8 +652,16 @@ __global__ void scan_repeated_message_children_kernel( } data_length = vbytes; } else if (wt == WT_32BIT) { + if (msg_end - cur < 4) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } data_length = 4; } else if (wt == WT_64BIT) { + if (msg_end - cur < 8) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } data_length = 8; } child_locs[occ_idx * num_child_fields + f] = {data_offset, data_length}; From 64de422f528fd81e7c4423b84d19394e2536d3e0 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 5 Mar 2026 23:37:05 +0800 Subject: [PATCH 061/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_builders.cu | 19 +++++++++++++------ src/main/cpp/src/protobuf_kernels.cu | 8 +++++++- .../rapids/jni/ProtobufSchemaDescriptor.java | 4 ++++ 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index f2ff0a89fb..5cb3df1b27 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -418,12 +418,19 @@ std::unique_ptr build_repeated_enum_string_column( auto const* occs = d_occurrences.data(); auto const* elem_invalid = d_elem_has_invalid_enum.data(); auto* row_invalid = d_row_has_invalid_enum.data(); - thrust::for_each(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_count), - [occs, elem_invalid, row_invalid] __device__(int idx) { - if (elem_invalid[idx]) { row_invalid[occs[idx].row_idx] = true; } - }); + thrust::for_each( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + [occs, elem_invalid, row_invalid] __device__(int idx) { + if (elem_invalid[idx]) { + auto* addr = reinterpret_cast( + reinterpret_cast(row_invalid + occs[idx].row_idx) & ~uintptr_t{3}); + unsigned int byte_offset = + 
(reinterpret_cast(row_invalid + occs[idx].row_idx) & 3u) * 8u; + atomicOr(addr, 1u << byte_offset); + } + }); } // 4. Compute per-element string lengths diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index d94809c317..36574a5acd 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -683,7 +683,11 @@ __global__ void scan_repeated_message_children_kernel( /** * Count repeated field occurrences within nested messages. * Similar to count_repeated_fields_kernel but operates on nested message locations. - + * + * Note: unlike the top-level count_repeated_fields_kernel, this kernel does not perform + * a depth-level check because it operates within a specific parent message context where + * the depth is implicitly fixed. Callers must pre-filter repeated_indices to include only + * fields at the expected child depth. */ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, cudf::size_type const* row_offsets, @@ -750,6 +754,8 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, * Scan for repeated field occurrences of a single repeated field within nested * messages. Must be launched once per repeated field — the caller passes * exactly one schema index via repeated_indices[0]. + * + * Note: no depth-level check is performed; see count_repeated_in_nested_kernel comment. 
*/ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, cudf::size_type const* row_offsets, diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index 3c35784d63..cb0606e0ba 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -191,6 +191,10 @@ private static void validate( "enumNames[" + i + "].length (" + enumNames[i].length + ") must equal " + "enumValidValues[" + i + "].length (" + ev.length + ")"); } + } else if (enumNames[i] != null) { + throw new IllegalArgumentException( + "enumNames[" + i + "] is non-null but enumValidValues[" + i + "] is null; " + + "both must be provided together for enum-as-string fields"); } } } From aae6bca543aa2bd6084bfb9058236bc0e388b949 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 6 Mar 2026 10:08:53 +0800 Subject: [PATCH 062/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 5 ++- src/main/cpp/src/protobuf_builders.cu | 35 ++++++++++++------- src/main/cpp/src/protobuf_common.cuh | 5 +++ src/main/cpp/src/protobuf_kernels.cu | 30 +++++++++++----- .../rapids/jni/ProtobufSchemaDescriptor.java | 18 ++++++++++ 5 files changed, 70 insertions(+), 23 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index c2bb1de269..2ce4fadcd9 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -80,7 +80,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Extract shared input data pointers (used by scalar, repeated, and nested sections) cudf::lists_column_view const in_list_view(binary_input); auto const* message_data = reinterpret_cast(in_list_view.child().data()); - auto const* list_offsets = in_list_view.offsets().data(); + auto const message_data_size = 
static_cast(in_list_view.child().size()); + auto const* list_offsets = in_list_view.offsets().data(); cudf::size_type base_offset = 0; CUDF_CUDA_TRY(cudaMemcpyAsync( @@ -902,6 +903,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } else { column_map[schema_idx] = build_repeated_struct_column(binary_input, message_data, + message_data_size, list_offsets, base_offset, h_device_schema[schema_idx], @@ -990,6 +992,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_nested_locations.data(), ni, num_nested, d_parent_locs.data(), num_rows); column_map[parent_schema_idx] = build_nested_struct_column(message_data, + message_data_size, list_offsets, base_offset, d_parent_locs, diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 5cb3df1b27..620d5fa58a 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -418,19 +418,16 @@ std::unique_ptr build_repeated_enum_string_column( auto const* occs = d_occurrences.data(); auto const* elem_invalid = d_elem_has_invalid_enum.data(); auto* row_invalid = d_row_has_invalid_enum.data(); - thrust::for_each( - rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_count), - [occs, elem_invalid, row_invalid] __device__(int idx) { - if (elem_invalid[idx]) { - auto* addr = reinterpret_cast( - reinterpret_cast(row_invalid + occs[idx].row_idx) & ~uintptr_t{3}); - unsigned int byte_offset = - (reinterpret_cast(row_invalid + occs[idx].row_idx) & 3u) * 8u; - atomicOr(addr, 1u << byte_offset); - } - }); + thrust::for_each(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + [occs, elem_invalid, row_invalid] __device__(int idx) { + if (elem_invalid[idx]) { + // Safe: all threads write the same value (true). On sm_70+ byte stores + // are independently addressable and do not tear neighboring bytes. 
+ row_invalid[occs[idx].row_idx] = true; + } + }); } // 4. Compute per-element string lengths @@ -611,6 +608,7 @@ std::unique_ptr build_repeated_string_column( // but the latter's STRUCT-child case needs to call it. std::unique_ptr build_nested_struct_column( uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* list_offsets, cudf::size_type base_offset, rmm::device_uvector const& d_parent_locs, @@ -636,6 +634,7 @@ std::unique_ptr build_nested_struct_column( // need to call it. std::unique_ptr build_repeated_child_list_column( uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* row_offsets, cudf::size_type base_offset, field_location const* parent_locs, @@ -659,6 +658,7 @@ std::unique_ptr build_repeated_child_list_column( std::unique_ptr build_repeated_struct_column( cudf::column_view const& binary_input, uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* list_offsets, cudf::size_type base_offset, device_nested_field_descriptor const& field_desc, @@ -786,6 +786,7 @@ std::unique_ptr build_repeated_struct_column( // This is similar to scan_nested_message_fields_kernel but operates on occurrences scan_repeated_message_children_kernel<<>>( message_data, + message_data_size, d_msg_row_offsets.data(), d_msg_locs.data(), total_count, @@ -811,6 +812,7 @@ std::unique_ptr build_repeated_struct_column( if (child_is_repeated) { struct_children.push_back(build_repeated_child_list_column(message_data, + message_data_size, d_msg_row_offsets_size.data(), 0, d_msg_locs.data(), @@ -938,6 +940,7 @@ std::unique_ptr build_repeated_struct_column( } struct_children.push_back(build_nested_struct_column(message_data, + message_data_size, d_nested_row_offsets.data(), base_offset, d_nested_locs, @@ -994,6 +997,7 @@ std::unique_ptr build_repeated_struct_column( std::unique_ptr build_nested_struct_column( uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type 
const* list_offsets, cudf::size_type base_offset, rmm::device_uvector const& d_parent_locs, @@ -1059,6 +1063,7 @@ std::unique_ptr build_nested_struct_column( static_cast(num_rows) * num_child_fields, stream, mr); scan_nested_message_fields_kernel<<>>( message_data, + message_data_size, list_offsets, base_offset, d_parent_locs.data(), @@ -1078,6 +1083,7 @@ std::unique_ptr build_nested_struct_column( if (is_repeated) { struct_children.push_back(build_repeated_child_list_column(message_data, + message_data_size, list_offsets, base_offset, d_parent_locs.data(), @@ -1267,6 +1273,7 @@ std::unique_ptr build_nested_struct_column( d_gc_parent.data(), num_rows); struct_children.push_back(build_nested_struct_column(message_data, + message_data_size, list_offsets, base_offset, d_gc_parent, @@ -1310,6 +1317,7 @@ std::unique_ptr build_nested_struct_column( */ std::unique_ptr build_repeated_child_list_column( uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* row_offsets, cudf::size_type base_offset, field_location const* parent_locs, @@ -1489,6 +1497,7 @@ std::unique_ptr build_repeated_child_list_column( total_rep_count); child_values = build_nested_struct_column(message_data, + message_data_size, d_virtual_row_offsets.data(), base_offset, d_virtual_parent_locs, diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index a7bea4bb52..b85110609f 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -1048,6 +1048,7 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co int fn_to_desc_size = 0); __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* parent_row_offsets, cudf::size_type parent_base_offset, field_location const* parent_locations, @@ -1058,6 +1059,7 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, int* 
error_flag); __global__ void scan_repeated_message_children_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, int32_t const* msg_row_offsets, field_location const* msg_locs, int num_occurrences, @@ -1217,6 +1219,7 @@ std::unique_ptr build_repeated_string_column( std::unique_ptr build_nested_struct_column( uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* list_offsets, cudf::size_type base_offset, rmm::device_uvector const& d_parent_locs, @@ -1239,6 +1242,7 @@ std::unique_ptr build_nested_struct_column( std::unique_ptr build_repeated_child_list_column( uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* row_offsets, cudf::size_type base_offset, field_location const* parent_locs, @@ -1262,6 +1266,7 @@ std::unique_ptr build_repeated_child_list_column( std::unique_ptr build_repeated_struct_column( cudf::column_view const& binary_input, uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* list_offsets, cudf::size_type base_offset, device_nested_field_descriptor const& field_desc, diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 36574a5acd..1d78df8c67 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -502,6 +502,7 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co * This kernel finds fields within the nested message bytes. 
*/ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* parent_row_offsets, cudf::size_type parent_base_offset, field_location const* parent_locations, @@ -521,8 +522,14 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, auto const& parent_loc = parent_locations[row]; if (parent_loc.offset < 0) { return; } - auto parent_row_start = parent_row_offsets[row] - parent_base_offset; - uint8_t const* nested_start = message_data + parent_row_start + parent_loc.offset; + auto parent_row_start = parent_row_offsets[row] - parent_base_offset; + int64_t nested_start_off = static_cast(parent_row_start) + parent_loc.offset; + int64_t nested_end_off = nested_start_off + parent_loc.length; + if (nested_start_off < 0 || nested_end_off > message_data_size) { + set_error_once(error_flag, ERR_BOUNDS); + return; + } + uint8_t const* nested_start = message_data + nested_start_off; uint8_t const* nested_end = nested_start + parent_loc.length; uint8_t const* cur = nested_start; @@ -582,6 +589,7 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, */ __global__ void scan_repeated_message_children_kernel( uint8_t const* message_data, + cudf::size_type message_data_size, int32_t const* msg_row_offsets, // Row offset for each occurrence field_location const* msg_locs, // Location of each message occurrence (offset within row, length) @@ -602,9 +610,14 @@ __global__ void scan_repeated_message_children_kernel( auto const& msg_loc = msg_locs[occ_idx]; if (msg_loc.offset < 0) return; - // Calculate absolute position of this message in the data - int32_t row_offset = msg_row_offsets[occ_idx]; - uint8_t const* msg_start = message_data + row_offset + msg_loc.offset; + int32_t row_offset = msg_row_offsets[occ_idx]; + int64_t msg_start_off = static_cast(row_offset) + msg_loc.offset; + int64_t msg_end_off = msg_start_off + msg_loc.length; + if (msg_start_off < 
0 || msg_end_off > message_data_size) { + set_error_once(error_flag, ERR_BOUNDS); + return; + } + uint8_t const* msg_start = message_data + msg_start_off; uint8_t const* msg_end = msg_start + msg_loc.length; uint8_t const* cur = msg_start; @@ -830,10 +843,9 @@ __global__ void compute_nested_struct_locations_kernel( int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= total_count) return; - // Get the nested struct location from child_locs - nested_locs[idx] = child_locs[idx * num_child_fields + child_idx]; - // Compute absolute row offset = msg_row_offset + msg_offset - nested_row_offsets[idx] = msg_row_offsets[idx] + msg_locs[idx].offset; + nested_locs[idx] = child_locs[idx * num_child_fields + child_idx]; + auto sum = static_cast(msg_row_offsets[idx]) + msg_locs[idx].offset; + nested_row_offsets[idx] = static_cast(sum); } /** diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index cb0606e0ba..0904d0fcb7 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -165,6 +165,24 @@ private static void validate( "Invalid field number at index " + i + ": " + fieldNumbers[i] + " (must be 1-" + MAX_FIELD_NUMBER + ")"); } + int pi = parentIndices[i]; + if (pi < -1 || pi >= i) { + throw new IllegalArgumentException( + "Invalid parent index at index " + i + ": " + pi + + " (must be -1 or a prior index < " + i + ")"); + } + if (pi == -1) { + if (depthLevels[i] != 0) { + throw new IllegalArgumentException( + "Top-level field at index " + i + " must have depth 0, got " + depthLevels[i]); + } + } else { + if (depthLevels[i] != depthLevels[pi] + 1) { + throw new IllegalArgumentException( + "Field at index " + i + " depth (" + depthLevels[i] + + ") must be parent depth (" + depthLevels[pi] + ") + 1"); + } + } int wt = wireTypes[i]; if (wt != 0 && wt != 1 && wt 
!= 2 && wt != 5) { throw new IllegalArgumentException( From 6e1c4d2154eed75d3b85183d06f72b32fdd2133d Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 6 Mar 2026 11:32:16 +0800 Subject: [PATCH 063/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_builders.cu | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 620d5fa58a..93f3c62036 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -1177,11 +1177,19 @@ std::unique_ptr build_nested_struct_column( stream, mr)); } else { - CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value())); + { + int err_val = ERR_MISSING_ENUM_META; + CUDF_CUDA_TRY(cudaMemcpyAsync( + d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); + } struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); } } else { - CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 1, sizeof(int), stream.value())); + { + int err_val = ERR_MISSING_ENUM_META; + CUDF_CUDA_TRY(cudaMemcpyAsync( + d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); + } struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); } } else { From 9c6dd70d02f5803f9bcb520bac2dd011a93ab332 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 6 Mar 2026 12:05:21 +0800 Subject: [PATCH 064/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 5 ++++- src/main/cpp/src/protobuf.cu | 2 +- src/main/cpp/src/protobuf_builders.cu | 16 ++++++++++------ src/main/cpp/src/protobuf_common.cuh | 1 + 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index f8dce47a91..eac0cce59f 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -88,7 +88,10 @@ 
Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, n_encodings.size() != num_fields || n_is_repeated.size() != num_fields || n_is_required.size() != num_fields || n_has_default.size() != num_fields || n_default_ints.size() != num_fields || n_default_floats.size() != num_fields || - n_default_bools.size() != num_fields) { + n_default_bools.size() != num_fields || + env->GetArrayLength(default_strings) != num_fields || + env->GetArrayLength(enum_valid_values) != num_fields || + env->GetArrayLength(enum_names) != num_fields) { JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, "All field arrays must have the same length", diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 2ce4fadcd9..111b25a042 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -415,7 +415,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& nf * sizeof(h_descs[0]), cudaMemcpyHostToDevice, stream.value())); - dim3 grid((num_rows + threads - 1) / threads, nf); + dim3 grid((num_rows + threads - 1u) / threads, nf); kernel_fn(grid, threads, stream.value(), diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 93f3c62036..063c93f681 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -272,10 +272,12 @@ std::unique_ptr build_enum_string_column( num_rows); std::vector h_name_offsets(valid_enums.size() + 1, 0); - int32_t total_name_chars = 0; + int64_t total_name_chars = 0; for (size_t k = 0; k < enum_name_bytes.size(); ++k) { - total_name_chars += static_cast(enum_name_bytes[k].size()); - h_name_offsets[k + 1] = total_name_chars; + total_name_chars += static_cast(enum_name_bytes[k].size()); + CUDF_EXPECTS(total_name_chars <= std::numeric_limits::max(), + "Enum name data exceeds 2 GB limit"); + h_name_offsets[k + 1] = static_cast(total_name_chars); } std::vector h_name_chars(total_name_chars); int32_t cursor = 0; @@ 
-368,10 +370,12 @@ std::unique_ptr build_repeated_enum_string_column( stream.value())); std::vector h_name_offsets(valid_enums.size() + 1, 0); - int32_t total_name_chars = 0; + int64_t total_name_chars = 0; for (size_t k = 0; k < enum_name_bytes.size(); ++k) { - total_name_chars += static_cast(enum_name_bytes[k].size()); - h_name_offsets[k + 1] = total_name_chars; + total_name_chars += static_cast(enum_name_bytes[k].size()); + CUDF_EXPECTS(total_name_chars <= std::numeric_limits::max(), + "Enum name data exceeds 2 GB limit"); + h_name_offsets[k + 1] = static_cast(total_name_chars); } std::vector h_name_chars(total_name_chars); int32_t cursor = 0; diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index b85110609f..3cf9c02dbf 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -44,6 +44,7 @@ #include #include +#include #include #include From 6b2f494130eef4927919d36eb0d263eb2482e145 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 6 Mar 2026 15:43:51 +0800 Subject: [PATCH 065/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_builders.cu | 14 ++++-- src/main/cpp/src/protobuf_common.cuh | 14 ++++-- src/main/cpp/src/protobuf_kernels.cu | 70 ++++++++++++++++++++++----- 3 files changed, 78 insertions(+), 20 deletions(-) diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 063c93f681..fd89133a2a 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -770,7 +770,8 @@ std::unique_ptr build_repeated_struct_column( base_offset, d_msg_locs.data(), d_msg_row_offsets.data(), - total_count); + total_count, + d_error_top.data()); } thrust::transform(rmm::exec_policy(stream), d_msg_row_offsets.data(), @@ -932,7 +933,8 @@ std::unique_ptr build_repeated_struct_column( num_child_fields, d_nested_locs.data(), d_nested_row_offsets_i32.data(), - total_count); + total_count, + 
d_error_top.data()); // Add base_offset back so build_nested_struct_column can subtract it thrust::transform(rmm::exec_policy(stream), d_nested_row_offsets_i32.data(), @@ -1283,7 +1285,8 @@ std::unique_ptr build_nested_struct_column( ci, num_child_fields, d_gc_parent.data(), - num_rows); + num_rows, + d_error.data()); struct_children.push_back(build_nested_struct_column(message_data, message_data_size, list_offsets, @@ -1381,6 +1384,7 @@ std::unique_ptr build_repeated_child_list_column( stream.value())); count_repeated_in_nested_kernel<<>>(message_data, + message_data_size, row_offsets, base_offset, parent_locs, @@ -1431,6 +1435,7 @@ std::unique_ptr build_repeated_child_list_column( rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); scan_repeated_in_nested_kernel<<>>(message_data, + message_data_size, row_offsets, base_offset, parent_locs, @@ -1506,7 +1511,8 @@ std::unique_ptr build_repeated_child_list_column( parent_locs, d_virtual_row_offsets.data(), d_virtual_parent_locs.data(), - total_rep_count); + total_rep_count, + d_error.data()); child_values = build_nested_struct_column(message_data, message_data_size, diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 3cf9c02dbf..24095c67a6 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -1070,6 +1070,7 @@ __global__ void scan_repeated_message_children_kernel(uint8_t const* message_dat int* error_flag); __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* row_offsets, cudf::size_type base_offset, field_location const* parent_locs, @@ -1082,6 +1083,7 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, int* error_flag); __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* row_offsets, cudf::size_type base_offset, field_location 
const* parent_locs, @@ -1099,14 +1101,16 @@ __global__ void compute_nested_struct_locations_kernel(field_location const* chi int num_child_fields, field_location* nested_locs, int32_t* nested_row_offsets, - int total_count); + int total_count, + int* error_flag); __global__ void compute_grandchild_parent_locations_kernel(field_location const* parent_locs, field_location const* child_locs, int child_idx, int num_child_fields, field_location* gc_parent_abs, - int num_rows); + int num_rows, + int* error_flag); __global__ void compute_virtual_parents_for_nested_repeated_kernel( repeated_occurrence const* occurrences, @@ -1114,7 +1118,8 @@ __global__ void compute_virtual_parents_for_nested_repeated_kernel( field_location const* parent_locations, cudf::size_type* virtual_row_offsets, field_location* virtual_parent_locs, - int total_count); + int total_count, + int* error_flag); __global__ void compute_msg_locations_from_occurrences_kernel( repeated_occurrence const* occurrences, @@ -1122,7 +1127,8 @@ __global__ void compute_msg_locations_from_occurrences_kernel( cudf::size_type base_offset, field_location* msg_locs, int32_t* msg_row_offsets, - int total_count); + int total_count, + int* error_flag); __global__ void extract_strided_locations_kernel(field_location const* nested_locations, int field_idx, diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 1d78df8c67..ed3d2e4b6e 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -703,6 +703,7 @@ __global__ void scan_repeated_message_children_kernel( * fields at the expected child depth. 
*/ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* row_offsets, cudf::size_type base_offset, field_location const* parent_locs, @@ -728,7 +729,14 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, cudf::size_type row_off; row_off = row_offsets[row] - base_offset; - uint8_t const* msg_start = message_data + row_off + parent_loc.offset; + int64_t msg_start_off = static_cast(row_off) + parent_loc.offset; + int64_t msg_end_off = msg_start_off + parent_loc.length; + if (msg_start_off < 0 || msg_end_off > message_data_size) { + set_error_once(error_flag, ERR_BOUNDS); + return; + } + + uint8_t const* msg_start = message_data + msg_start_off; uint8_t const* msg_end = msg_start + parent_loc.length; uint8_t const* cur = msg_start; @@ -771,6 +779,7 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, * Note: no depth-level check is performed; see count_repeated_in_nested_kernel comment. 
*/ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, cudf::size_type const* row_offsets, cudf::size_type base_offset, field_location const* parent_locs, @@ -789,7 +798,14 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, cudf::size_type row_off = row_offsets[row] - base_offset; - uint8_t const* msg_start = message_data + row_off + parent_loc.offset; + int64_t msg_start_off = static_cast(row_off) + parent_loc.offset; + int64_t msg_end_off = msg_start_off + parent_loc.length; + if (msg_start_off < 0 || msg_end_off > message_data_size) { + set_error_once(error_flag, ERR_BOUNDS); + return; + } + + uint8_t const* msg_start = message_data + msg_start_off; uint8_t const* msg_end = msg_start + parent_loc.length; uint8_t const* cur = msg_start; @@ -838,13 +854,20 @@ __global__ void compute_nested_struct_locations_kernel( int num_child_fields, // Total number of child fields per occurrence field_location* nested_locs, // Output: nested struct locations int32_t* nested_row_offsets, // Output: nested struct row offsets - int total_count) + int total_count, + int* error_flag) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= total_count) return; - nested_locs[idx] = child_locs[idx * num_child_fields + child_idx]; - auto sum = static_cast(msg_row_offsets[idx]) + msg_locs[idx].offset; + nested_locs[idx] = child_locs[idx * num_child_fields + child_idx]; + auto sum = static_cast(msg_row_offsets[idx]) + msg_locs[idx].offset; + if (sum < std::numeric_limits::min() || sum > std::numeric_limits::max()) { + nested_locs[idx] = {-1, 0}; + nested_row_offsets[idx] = 0; + set_error_once(error_flag, ERR_OVERFLOW); + return; + } nested_row_offsets[idx] = static_cast(sum); } @@ -859,7 +882,8 @@ __global__ void compute_grandchild_parent_locations_kernel( int child_idx, // Which child field int num_child_fields, // Total child fields per row field_location* gc_parent_abs, // Output: absolute 
grandchild parent locations - int num_rows) + int num_rows, + int* error_flag) { int row = blockIdx.x * blockDim.x + threadIdx.x; if (row >= num_rows) return; @@ -869,7 +893,13 @@ __global__ void compute_grandchild_parent_locations_kernel( if (parent_loc.offset >= 0 && child_loc.offset >= 0) { // Absolute offset = parent offset + child's relative offset - gc_parent_abs[row].offset = parent_loc.offset + child_loc.offset; + auto sum = static_cast(parent_loc.offset) + child_loc.offset; + if (sum < std::numeric_limits::min() || sum > std::numeric_limits::max()) { + gc_parent_abs[row] = {-1, 0}; + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + gc_parent_abs[row].offset = static_cast(sum); gc_parent_abs[row].length = child_loc.length; } else { gc_parent_abs[row] = {-1, 0}; @@ -887,7 +917,8 @@ __global__ void compute_virtual_parents_for_nested_repeated_kernel( field_location const* parent_locations, // parent nested message locations cudf::size_type* virtual_row_offsets, // output: [total_count] field_location* virtual_parent_locs, // output: [total_count] - int total_count) + int total_count, + int* error_flag) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= total_count) return; @@ -901,7 +932,13 @@ __global__ void compute_virtual_parents_for_nested_repeated_kernel( // Protobuf allows an embedded message with length=0, which maps to a non-null // struct with all-null children (not a null struct). 
if (ploc.offset >= 0) { - virtual_parent_locs[idx] = {ploc.offset + occ.offset, occ.length}; + auto sum = static_cast(ploc.offset) + occ.offset; + if (sum < std::numeric_limits::min() || sum > std::numeric_limits::max()) { + virtual_parent_locs[idx] = {-1, 0}; + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + virtual_parent_locs[idx] = {static_cast(sum), occ.length}; } else { virtual_parent_locs[idx] = {-1, 0}; } @@ -917,13 +954,22 @@ __global__ void compute_msg_locations_from_occurrences_kernel( cudf::size_type base_offset, // Base offset to subtract field_location* msg_locs, // Output: message locations int32_t* msg_row_offsets, // Output: message row offsets - int total_count) + int total_count, + int* error_flag) { int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= total_count) return; - auto const& occ = occurrences[idx]; - msg_row_offsets[idx] = static_cast(list_offsets[occ.row_idx] - base_offset); + auto const& occ = occurrences[idx]; + auto row_offset = static_cast(list_offsets[occ.row_idx]) - base_offset; + if (row_offset < std::numeric_limits::min() || + row_offset > std::numeric_limits::max()) { + msg_row_offsets[idx] = 0; + msg_locs[idx] = {-1, 0}; + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + msg_row_offsets[idx] = static_cast(row_offset); msg_locs[idx] = {occ.offset, occ.length}; } From e40d5a7963618b0d5801c809f0cb2cce335fa4a9 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 6 Mar 2026 16:26:02 +0800 Subject: [PATCH 066/107] Lookup table for repeated message child scan Signed-off-by: Haoyang Li --- src/main/cpp/benchmarks/protobuf_decode.cu | 147 ++++++++++++++++++++- src/main/cpp/src/protobuf_builders.cu | 14 +- src/main/cpp/src/protobuf_common.cuh | 4 +- src/main/cpp/src/protobuf_kernels.cu | 87 ++++++------ 4 files changed, 205 insertions(+), 47 deletions(-) diff --git a/src/main/cpp/benchmarks/protobuf_decode.cu b/src/main/cpp/benchmarks/protobuf_decode.cu index 2f48a431dd..41a80b49ab 100644 --- 
a/src/main/cpp/benchmarks/protobuf_decode.cu +++ b/src/main/cpp/benchmarks/protobuf_decode.cu @@ -439,6 +439,115 @@ struct RepeatedFieldCase { } }; +// Case 4: Wide repeated message — stress-tests repeated struct child scanning. +// message WideRepeatedMessage { +// int32 id = 1; +// repeated Item items = 2; +// } +// message Item { +// int32 / int64 / float / double / bool / string child fields ... +// ... (num_child_fields fields) +// } +// +// This case is intentionally generic and contains no customer schema details. +// It is designed to exercise `scan_repeated_message_children_kernel` with a +// wide repeated STRUCT payload similar in shape to real-world schema-projection +// workloads. +struct WideRepeatedMessageCase { + int num_child_fields; + int avg_items_per_row; + + spark_rapids_jni::ProtobufDecodeContext build_context() const + { + spark_rapids_jni::ProtobufDecodeContext ctx; + ctx.fail_on_errors = true; + + // idx 0: id (scalar) + ctx.schema.push_back({1, -1, 0, 0, cudf::type_id::INT32, 0, false, false, false}); + // idx 1: items (repeated STRUCT) + ctx.schema.push_back({2, -1, 0, 2, cudf::type_id::STRUCT, 0, true, false, false}); + + cudf::type_id child_types[] = {cudf::type_id::INT32, + cudf::type_id::INT64, + cudf::type_id::FLOAT32, + cudf::type_id::FLOAT64, + cudf::type_id::BOOL8, + cudf::type_id::STRING}; + int child_wt[] = {0, 0, 5, 1, 0, 2}; + int child_enc[] = {spark_rapids_jni::ENC_DEFAULT, + spark_rapids_jni::ENC_DEFAULT, + spark_rapids_jni::ENC_FIXED, + spark_rapids_jni::ENC_FIXED, + spark_rapids_jni::ENC_DEFAULT, + spark_rapids_jni::ENC_DEFAULT}; + + // Keep strings sparse so the case remains dominated by wide child scanning + // rather than varlen copy traffic. + for (int i = 0; i < num_child_fields; i++) { + int ti = (i % 10 == 9) ? 
5 : (i % 5); + ctx.schema.push_back( + {i + 1, 1, 1, child_wt[ti], child_types[ti], child_enc[ti], false, false, false}); + } + + size_t n = ctx.schema.size(); + for (auto const& f : ctx.schema) { + ctx.schema_output_types.emplace_back(f.output_type); + } + ctx.default_ints.resize(n, 0); + ctx.default_floats.resize(n, 0.0); + ctx.default_bools.resize(n, false); + ctx.default_strings.resize(n); + ctx.enum_valid_values.resize(n); + ctx.enum_names.resize(n); + return ctx; + } + + std::vector> generate_messages(int num_rows, std::mt19937& rng) const + { + std::vector> messages(num_rows); + std::uniform_int_distribution int_dist(0, 100000); + std::uniform_int_distribution str_len_dist(6, 18); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + + auto random_string = [&](int len) { + std::string s(len, ' '); + for (int c = 0; c < len; c++) + s[c] = alphabet[rng() % alphabet.size()]; + return s; + }; + + auto vary = [&](int avg) -> int { + int lo = std::max(0, avg / 2); + int hi = avg + avg / 2; + return std::uniform_int_distribution(lo, std::max(lo, hi))(rng); + }; + + for (int r = 0; r < num_rows; r++) { + auto& buf = messages[r]; + encode_varint_field(buf, 1, int_dist(rng)); + + int n = vary(avg_items_per_row); + for (int item_idx = 0; item_idx < n; item_idx++) { + encode_nested_message(buf, 2, [&](std::vector& inner) { + for (int i = 0; i < num_child_fields; i++) { + int ti = (i % 10 == 9) ? 
5 : (i % 5); + int fn = i + 1; + switch (ti) { + case 0: encode_varint_field(inner, fn, int_dist(rng)); break; + case 1: encode_varint_field(inner, fn, int_dist(rng)); break; + case 2: encode_fixed32_field(inner, fn, static_cast(int_dist(rng))); break; + case 3: encode_fixed64_field(inner, fn, static_cast(int_dist(rng))); break; + case 4: encode_varint_field(inner, fn, rng() % 2); break; + case 5: encode_string_field(inner, fn, random_string(str_len_dist(rng))); break; + } + } + }); + } + } + return messages; + } +}; + // Case 4: Many repeated fields — stress-tests per-repeated-field sync overhead. // message WideRepeatedMessage { // int32 id = 1; @@ -638,7 +747,43 @@ NVBENCH_BENCH(BM_protobuf_repeated) .add_int64_axis("avg_items", {1, 5, 20}); // =========================================================================== -// Benchmark 4: Many repeated fields — measures per-field sync overhead at scale +// Benchmark 4: Wide repeated message — measures repeated struct child scan cost +// =========================================================================== +static void BM_protobuf_wide_repeated_message(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const num_child_fields = static_cast(state.get_int64("num_child_fields")); + auto const avg_items = static_cast(state.get_int64("avg_items")); + + WideRepeatedMessageCase wide_case{num_child_fields, avg_items}; + auto ctx = wide_case.build_context(); + + std::mt19937 rng(42); + auto messages = wide_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); + + size_t total_bytes = 0; + for (auto const& m : messages) + total_bytes += m.size(); + + auto stream = cudf::get_default_stream(); + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + auto result = spark_rapids_jni::decode_protobuf_to_struct(binary_col->view(), ctx, stream); + }); + + 
state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_wide_repeated_message) + .set_name("Protobuf Wide Repeated Message") + .add_int64_axis("num_rows", {10'000, 20'000}) + .add_int64_axis("num_child_fields", {20, 100, 200}) + .add_int64_axis("avg_items", {1, 5, 10}); + +// =========================================================================== +// Benchmark 5: Many repeated fields — measures per-field sync overhead at scale // =========================================================================== static void BM_protobuf_many_repeated(nvbench::state& state) { diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index fd89133a2a..f444f3941e 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -754,6 +754,16 @@ std::unique_ptr build_repeated_struct_column( num_child_fields * sizeof(field_descriptor), cudaMemcpyHostToDevice, stream.value())); + auto h_child_lookup = build_field_lookup_table(h_child_descs.data(), num_child_fields); + rmm::device_uvector d_child_lookup(0, stream, mr); + if (!h_child_lookup.empty()) { + d_child_lookup = rmm::device_uvector(h_child_lookup.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_lookup.data(), + h_child_lookup.data(), + h_child_lookup.size() * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + } // For each occurrence, we need to scan for child fields // Create "virtual" parent locations from the occurrences using GPU kernel @@ -798,7 +808,9 @@ std::unique_ptr build_repeated_struct_column( d_child_descs.data(), num_child_fields, d_child_locs.data(), - d_error.data()); + d_error.data(), + h_child_lookup.empty() ? nullptr : d_child_lookup.data(), + static_cast(d_child_lookup.size())); // Note: We no longer need to copy child_locs to host because: // 1. 
All scalar extraction kernels access d_child_locs directly on device diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 24095c67a6..47577c8bc5 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -1067,7 +1067,9 @@ __global__ void scan_repeated_message_children_kernel(uint8_t const* message_dat field_descriptor const* child_descs, int num_child_fields, field_location* child_locs, - int* error_flag); + int* error_flag, + int const* child_lookup = nullptr, + int child_lookup_size = 0); __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, cudf::size_type message_data_size, diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index ed3d2e4b6e..8b9d8f32e8 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -597,7 +597,9 @@ __global__ void scan_repeated_message_children_kernel( field_descriptor const* child_descs, int num_child_fields, field_location* child_locs, // Output: [num_occurrences * num_child_fields] - int* error_flag) + int* error_flag, + int const* child_lookup, + int child_lookup_size) { auto occ_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (occ_idx >= num_occurrences) return; @@ -628,58 +630,55 @@ __global__ void scan_repeated_message_children_kernel( int fn = tag.field_number; int wt = tag.wire_type; - // Check against child field descriptors - for (int f = 0; f < num_child_fields; f++) { - if (child_descs[f].field_number == fn) { - bool is_packed = (wt == WT_LEN && child_descs[f].expected_wire_type != WT_LEN); - if (!is_packed && wt != child_descs[f].expected_wire_type) { - set_error_once(error_flag, ERR_WIRE_TYPE); - return; - } + int f = lookup_field(fn, child_lookup, child_lookup_size, child_descs, num_child_fields); + if (f >= 0) { + bool is_packed = (wt == WT_LEN && child_descs[f].expected_wire_type != WT_LEN); + if (!is_packed && wt != 
child_descs[f].expected_wire_type) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return; + } - int data_offset = static_cast(cur - msg_start); + int data_offset = static_cast(cur - msg_start); - if (wt == WT_LEN) { - uint64_t len; - int len_bytes; - if (!read_varint(cur, msg_end, len, len_bytes)) { + if (wt == WT_LEN) { + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + if (len > static_cast(msg_end - cur - len_bytes) || + len > static_cast(INT_MAX)) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + child_locs[occ_idx * num_child_fields + f] = {data_offset + len_bytes, + static_cast(len)}; + } else { + // For varint/fixed types, store offset and estimated length + int32_t data_length = 0; + if (wt == WT_VARINT) { + uint64_t dummy; + int vbytes; + if (!read_varint(cur, msg_end, dummy, vbytes)) { set_error_once(error_flag, ERR_VARINT); return; } - if (len > static_cast(msg_end - cur - len_bytes) || - len > static_cast(INT_MAX)) { - set_error_once(error_flag, ERR_OVERFLOW); + data_length = vbytes; + } else if (wt == WT_32BIT) { + if (msg_end - cur < 4) { + set_error_once(error_flag, ERR_FIXED_LEN); return; } - child_locs[occ_idx * num_child_fields + f] = {data_offset + len_bytes, - static_cast(len)}; - } else { - // For varint/fixed types, store offset and estimated length - int32_t data_length = 0; - if (wt == WT_VARINT) { - uint64_t dummy; - int vbytes; - if (!read_varint(cur, msg_end, dummy, vbytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - data_length = vbytes; - } else if (wt == WT_32BIT) { - if (msg_end - cur < 4) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - data_length = 4; - } else if (wt == WT_64BIT) { - if (msg_end - cur < 8) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; - } - data_length = 8; + data_length = 4; + } else if (wt == WT_64BIT) { + if (msg_end - cur < 8) { + set_error_once(error_flag, 
ERR_FIXED_LEN); + return; } - child_locs[occ_idx * num_child_fields + f] = {data_offset, data_length}; + data_length = 8; } - // Don't break - last occurrence wins (protobuf semantics) + child_locs[occ_idx * num_child_fields + f] = {data_offset, data_length}; } } From 4b4f6f95bd59cb82182154bb2bf5eba329a00a3b Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 12 Mar 2026 10:03:37 +0800 Subject: [PATCH 067/107] reflection refactor Signed-off-by: Haoyang Li --- src/main/cpp/benchmarks/protobuf_decode.cu | 484 +++++++++++++++++- src/main/cpp/src/protobuf.cu | 46 +- src/main/cpp/src/protobuf.hpp | 58 +++ .../rapids/jni/ProtobufSchemaDescriptor.java | 10 + .../jni/ProtobufSchemaDescriptorTest.java | 64 +++ 5 files changed, 638 insertions(+), 24 deletions(-) create mode 100644 src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java diff --git a/src/main/cpp/benchmarks/protobuf_decode.cu b/src/main/cpp/benchmarks/protobuf_decode.cu index 41a80b49ab..f2470ac097 100644 --- a/src/main/cpp/benchmarks/protobuf_decode.cu +++ b/src/main/cpp/benchmarks/protobuf_decode.cu @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -158,6 +159,21 @@ std::unique_ptr make_binary_column(std::vector& buf, + int field_number, + std::string const& s, + std::vector& out_occurrences, + int32_t row_idx) +{ + encode_tag(buf, field_number, /*WT_LEN=*/2); + encode_varint(buf, s.size()); + auto const data_offset = static_cast(buf.size()); + buf.insert(buf.end(), s.begin(), s.end()); + out_occurrences.push_back({row_idx, data_offset, static_cast(s.size())}); +} // Case 1: Flat scalars only — many top-level scalar fields. // message FlatMessage { @@ -548,7 +564,171 @@ struct WideRepeatedMessageCase { } }; -// Case 4: Many repeated fields — stress-tests per-repeated-field sync overhead. +// Case 5: Repeated child lists — stress-tests repeated fields inside a repeated +// struct child, which exercises build_repeated_child_list_column(). 
+// message OuterMessage { +// int32 id = 1; +// repeated Item items = 2; +// } +// message Item { +// repeated int32 r_int_* = 1..N +// repeated string r_str_* = ... +// } +// +// This case is intentionally generic and contains no customer schema details. +struct RepeatedChildListCase { + int num_repeated_children; + int avg_items_per_row; + int avg_child_elems; + std::string child_mix; + + bool child_is_string(int child_idx) const + { + if (child_mix == "string_only") return true; + if (child_mix == "int_only") return false; + return (child_idx % 4 == 3); + } + + spark_rapids_jni::ProtobufDecodeContext build_context() const + { + spark_rapids_jni::ProtobufDecodeContext ctx; + ctx.fail_on_errors = true; + + // idx 0: id (scalar) + ctx.schema.push_back({1, -1, 0, 0, cudf::type_id::INT32, 0, false, false, false}); + // idx 1: items (repeated STRUCT) + ctx.schema.push_back({2, -1, 0, 2, cudf::type_id::STRUCT, 0, true, false, false}); + + for (int i = 0; i < num_repeated_children; i++) { + bool as_string = child_is_string(i); + ctx.schema.push_back({i + 1, + 1, + 1, + as_string ? 2 : 0, + as_string ? 
cudf::type_id::STRING : cudf::type_id::INT32, + 0, + true, + false, + false}); + } + + size_t n = ctx.schema.size(); + for (auto const& f : ctx.schema) { + ctx.schema_output_types.emplace_back(f.output_type); + } + ctx.default_ints.resize(n, 0); + ctx.default_floats.resize(n, 0.0); + ctx.default_bools.resize(n, false); + ctx.default_strings.resize(n); + ctx.enum_valid_values.resize(n); + ctx.enum_names.resize(n); + return ctx; + } + + std::vector> generate_messages(int num_rows, std::mt19937& rng) const + { + std::vector> messages(num_rows); + std::uniform_int_distribution int_dist(0, 100000); + std::uniform_int_distribution str_len_dist(4, 16); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + + auto random_string = [&](int len) { + std::string s(len, ' '); + for (int c = 0; c < len; c++) + s[c] = alphabet[rng() % alphabet.size()]; + return s; + }; + + auto vary = [&](int avg) -> int { + int lo = std::max(0, avg / 2); + int hi = avg + avg / 2; + return std::uniform_int_distribution(lo, std::max(lo, hi))(rng); + }; + + for (int r = 0; r < num_rows; r++) { + auto& buf = messages[r]; + encode_varint_field(buf, 1, int_dist(rng)); + + int num_items = vary(avg_items_per_row); + for (int item_idx = 0; item_idx < num_items; item_idx++) { + encode_nested_message(buf, 2, [&](std::vector& inner) { + for (int child_idx = 0; child_idx < num_repeated_children; child_idx++) { + int fn = child_idx + 1; + bool is_str = child_is_string(child_idx); + int num_elems = vary(avg_child_elems); + if (is_str) { + for (int j = 0; j < num_elems; j++) { + encode_string_field(inner, fn, random_string(str_len_dist(rng))); + } + } else { + if (num_elems > 0) { + std::vector vals(num_elems); + for (auto& v : vals) + v = int_dist(rng); + encode_packed_repeated_int32(inner, fn, vals); + } + } + } + }); + } + } + return messages; + } +}; + +struct RepeatedChildStringBenchData { + std::vector> messages; + std::vector parent_locs; + std::vector> counts_by_child; + std::vector> 
occurrences_by_child; +}; + +struct RepeatedChildStringOnlyCase { + int num_repeated_children; + int avg_child_elems; + + RepeatedChildStringBenchData generate_data(int num_rows, std::mt19937& rng) const + { + RepeatedChildStringBenchData out; + out.messages.resize(num_rows); + out.parent_locs.resize(num_rows); + out.counts_by_child.resize(num_repeated_children); + out.occurrences_by_child.resize(num_repeated_children); + + std::uniform_int_distribution str_len_dist(4, 16); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + + auto random_string = [&](int len) { + std::string s(len, ' '); + for (int c = 0; c < len; c++) + s[c] = alphabet[rng() % alphabet.size()]; + return s; + }; + + auto vary = [&](int avg) -> int { + int lo = std::max(0, avg / 2); + int hi = avg + avg / 2; + return std::uniform_int_distribution(lo, std::max(lo, hi))(rng); + }; + + for (int row = 0; row < num_rows; row++) { + auto& buf = out.messages[row]; + for (int child_idx = 0; child_idx < num_repeated_children; child_idx++) { + int fn = child_idx + 1; + int num_elems = vary(avg_child_elems); + out.counts_by_child[child_idx].push_back(num_elems); + for (int j = 0; j < num_elems; j++) { + encode_string_field_record( + buf, fn, random_string(str_len_dist(rng)), out.occurrences_by_child[child_idx], row); + } + } + out.parent_locs[row] = {0, static_cast(buf.size())}; + } + return out; + } +}; + +// Case 6: Many repeated fields — stress-tests per-repeated-field sync overhead. 
// message WideRepeatedMessage { // int32 id = 1; // repeated int32 r_int_1 = 2; @@ -783,7 +963,307 @@ NVBENCH_BENCH(BM_protobuf_wide_repeated_message) .add_int64_axis("avg_items", {1, 5, 10}); // =========================================================================== -// Benchmark 5: Many repeated fields — measures per-field sync overhead at scale +// Benchmark 5: Repeated child lists — measures repeated-in-nested list overhead +// =========================================================================== +static void BM_protobuf_repeated_child_lists(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const num_repeated_children = static_cast(state.get_int64("num_repeated_children")); + auto const avg_items = static_cast(state.get_int64("avg_items")); + auto const avg_child_elems = static_cast(state.get_int64("avg_child_elems")); + auto const child_mix = state.get_string("child_mix"); + + RepeatedChildListCase list_case{ + num_repeated_children, avg_items, avg_child_elems, std::string(child_mix)}; + auto ctx = list_case.build_context(); + + std::mt19937 rng(42); + auto messages = list_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); + + size_t total_bytes = 0; + for (auto const& m : messages) + total_bytes += m.size(); + + auto stream = cudf::get_default_stream(); + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + auto result = spark_rapids_jni::decode_protobuf_to_struct(binary_col->view(), ctx, stream); + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_repeated_child_lists) + .set_name("Protobuf Repeated Child Lists") + .add_int64_axis("num_rows", {10'000, 20'000}) + .add_int64_axis("num_repeated_children", {1, 4, 8}) + .add_int64_axis("avg_items", {1, 5}) + .add_int64_axis("avg_child_elems", {1, 5}) + 
.add_string_axis("child_mix", {"int_only", "mixed", "string_only"}); + +// =========================================================================== +// Benchmark 6: Repeated child string count+scan only +// =========================================================================== +static void BM_protobuf_repeated_child_string_count_scan(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const num_repeated_children = static_cast(state.get_int64("num_repeated_children")); + auto const avg_child_elems = static_cast(state.get_int64("avg_child_elems")); + + RepeatedChildStringOnlyCase list_case{num_repeated_children, avg_child_elems}; + std::mt19937 rng(42); + auto data = list_case.generate_data(num_rows, rng); + auto binary_col = make_binary_column(data.messages); + + cudf::lists_column_view in_list(binary_col->view()); + auto const* row_offsets = in_list.offsets().data(); + auto const* message_data = + reinterpret_cast(in_list.child().data()); + auto const message_data_size = static_cast(in_list.child().size()); + + auto stream = cudf::get_default_stream(); + auto mr = cudf::get_current_device_resource_ref(); + + rmm::device_uvector d_parent_locs(num_rows, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_parent_locs.data(), + data.parent_locs.data(), + num_rows * sizeof(pb_field_location), + cudaMemcpyHostToDevice, + stream.value())); + + std::vector h_schema( + num_repeated_children); + for (int i = 0; i < num_repeated_children; i++) { + h_schema[i].field_number = i + 1; + h_schema[i].parent_idx = -1; + h_schema[i].depth = 0; + h_schema[i].wire_type = spark_rapids_jni::protobuf_detail::WT_LEN; + h_schema[i].output_type_id = static_cast(cudf::type_id::STRING); + h_schema[i].encoding = 0; + h_schema[i].is_repeated = true; + h_schema[i].is_required = false; + h_schema[i].has_default_value = false; + } + rmm::device_uvector d_schema( + num_repeated_children, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_schema.data(), + 
h_schema.data(), + num_repeated_children * sizeof(h_schema[0]), + cudaMemcpyHostToDevice, + stream.value())); + + std::vector h_rep_indices(num_repeated_children); + for (int i = 0; i < num_repeated_children; i++) { + h_rep_indices[i] = i; + } + rmm::device_uvector d_rep_indices(num_repeated_children, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_indices.data(), + h_rep_indices.data(), + num_repeated_children * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + + size_t total_bytes = 0; + for (auto const& m : data.messages) + total_bytes += m.size(); + + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + rmm::device_uvector d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + + rmm::device_uvector d_rep_info( + static_cast(num_rows) * num_repeated_children, stream, mr); + spark_rapids_jni::protobuf_detail::count_repeated_in_nested_kernel<<< + (num_rows + 255) / 256, 256, 0, stream.value()>>>(message_data, + message_data_size, + row_offsets, + 0, + d_parent_locs.data(), + num_rows, + d_schema.data(), + num_repeated_children, + d_rep_info.data(), + num_repeated_children, + d_rep_indices.data(), + d_error.data()); + + struct rep_work { + rmm::device_uvector counts; + rmm::device_uvector offsets; + int32_t total_count{0}; + std::unique_ptr> occs; + rep_work(int n, rmm::cuda_stream_view s, rmm::device_async_resource_ref m) + : counts(n, s, m), offsets(n + 1, s, m) + { + } + }; + + std::vector> work; + work.reserve(num_repeated_children); + for (int ri = 0; ri < num_repeated_children; ri++) { + auto& w = *work.emplace_back(std::make_unique(num_rows, stream, mr)); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + w.counts.data(), + spark_rapids_jni::protobuf_detail::extract_strided_count{ + d_rep_info.data(), ri, num_repeated_children}); + 
CUDF_CUDA_TRY(cudaMemsetAsync(w.offsets.data(), 0, sizeof(int32_t), stream.value())); + thrust::inclusive_scan( + rmm::exec_policy(stream), w.counts.begin(), w.counts.end(), w.offsets.data() + 1); + CUDF_CUDA_TRY(cudaMemcpyAsync(&w.total_count, + w.offsets.data() + num_rows, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + } + stream.synchronize(); + + for (int ri = 0; ri < num_repeated_children; ri++) { + auto& w = *work[ri]; + if (w.total_count > 0) { + w.occs = std::make_unique>(w.total_count, + stream, + mr); + spark_rapids_jni::protobuf_detail::scan_repeated_in_nested_kernel<<< + (num_rows + 255) / 256, 256, 0, stream.value()>>>(message_data, + message_data_size, + row_offsets, + 0, + d_parent_locs.data(), + num_rows, + d_schema.data(), + w.offsets.data(), + d_rep_indices.data() + ri, + w.occs->data(), + d_error.data()); + } + } + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_repeated_child_string_count_scan) + .set_name("Protobuf Repeated Child String CountScan") + .add_int64_axis("num_rows", {10'000, 20'000}) + .add_int64_axis("num_repeated_children", {1, 4, 8}) + .add_int64_axis("avg_child_elems", {1, 5}); + +// =========================================================================== +// Benchmark 7: Repeated child string build-only +// =========================================================================== +static void BM_protobuf_repeated_child_string_build(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const num_repeated_children = static_cast(state.get_int64("num_repeated_children")); + auto const avg_child_elems = static_cast(state.get_int64("avg_child_elems")); + + RepeatedChildStringOnlyCase list_case{num_repeated_children, avg_child_elems}; + std::mt19937 rng(42); + auto data = list_case.generate_data(num_rows, rng); + auto binary_col = make_binary_column(data.messages); + + 
cudf::lists_column_view in_list(binary_col->view()); + auto const* row_offsets = in_list.offsets().data(); + auto const* message_data = + reinterpret_cast(in_list.child().data()); + + auto stream = cudf::get_default_stream(); + auto mr = cudf::get_current_device_resource_ref(); + + rmm::device_uvector d_parent_locs(num_rows, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_parent_locs.data(), + data.parent_locs.data(), + num_rows * sizeof(pb_field_location), + cudaMemcpyHostToDevice, + stream.value())); + + struct precomputed_child { + rmm::device_uvector counts; + rmm::device_uvector occs; + int total_count; + precomputed_child(int nrows, + int total, + rmm::cuda_stream_view s, + rmm::device_async_resource_ref m) + : counts(nrows, s, m), occs(total, s, m), total_count(total) + { + } + }; + + std::vector> children; + children.reserve(num_repeated_children); + for (int i = 0; i < num_repeated_children; i++) { + int total = static_cast(data.occurrences_by_child[i].size()); + auto& c = + *children.emplace_back(std::make_unique(num_rows, total, stream, mr)); + CUDF_CUDA_TRY(cudaMemcpyAsync(c.counts.data(), + data.counts_by_child[i].data(), + num_rows * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + if (total > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(c.occs.data(), + data.occurrences_by_child[i].data(), + total * sizeof(pb_repeated_occurrence), + cudaMemcpyHostToDevice, + stream.value())); + } + } + + size_t total_bytes = 0; + for (auto const& m : data.messages) + total_bytes += m.size(); + + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + rmm::device_uvector d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + + for (int i = 0; i < num_repeated_children; i++) { + auto& c = *children[i]; + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy(stream), c.counts.begin(), 
c.counts.end(), list_offs.begin(), 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, + &c.total_count, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + spark_rapids_jni::protobuf_detail::NestedRepeatedLocationProvider nr_loc{ + row_offsets, 0, d_parent_locs.data(), c.occs.data()}; + auto valid_fn = [] __device__(cudf::size_type) { return true; }; + std::vector empty_default; + auto child_values = spark_rapids_jni::protobuf_detail::extract_and_build_string_or_bytes_column( + false, message_data, c.total_count, nr_loc, nr_loc, valid_fn, false, empty_default, d_error, stream, mr); + auto list_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + auto result = cudf::make_lists_column( + num_rows, std::move(list_offs_col), std::move(child_values), 0, rmm::device_buffer{}); + } + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_repeated_child_string_build) + .set_name("Protobuf Repeated Child String Build") + .add_int64_axis("num_rows", {10'000, 20'000}) + .add_int64_axis("num_repeated_children", {1, 4, 8}) + .add_int64_axis("avg_child_elems", {1, 5}); + +// =========================================================================== +// Benchmark 8: Many repeated fields — measures per-field sync overhead at scale // =========================================================================== static void BM_protobuf_many_repeated(nvbench::state& state) { diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 111b25a042..5c6e947044 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -26,6 +26,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& ProtobufDecodeContext const& context, rmm::cuda_stream_view stream) { + validate_decode_context(context); auto const& schema = context.schema; auto const& 
schema_output_types = context.schema_output_types; auto const& default_ints = context.default_ints; @@ -462,10 +463,11 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Per-field fallback (INT32 with enum, etc.) for (int i : group_lists[GRP_FALLBACK]) { - int schema_idx = scalar_field_indices[i]; - auto const dt = schema_output_types[schema_idx]; - auto const enc = schema[schema_idx].encoding; - bool has_def = schema[schema_idx].has_default_value; + int schema_idx = scalar_field_indices[i]; + auto const field_meta = make_field_meta_view(context, schema_idx); + auto const dt = field_meta.output_type; + auto const enc = field_meta.schema.encoding; + bool has_def = field_meta.schema.has_default_value; TopLevelLocationProvider loc_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; column_map[schema_idx] = extract_typed_column(dt, @@ -476,10 +478,10 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& blocks, threads, has_def, - has_def ? default_ints[schema_idx] : 0, - has_def ? default_floats[schema_idx] : 0.0, - has_def ? default_bools[schema_idx] : false, - default_strings[schema_idx], + has_def ? field_meta.default_int : 0, + has_def ? field_meta.default_float : 0.0, + has_def ? 
field_meta.default_bool : false, + field_meta.default_string, schema_idx, enum_valid_values, enum_names, @@ -492,11 +494,12 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Per-field extraction for STRING and LIST types for (int i = 0; i < num_scalar; i++) { - int schema_idx = scalar_field_indices[i]; - auto const dt = schema_output_types[schema_idx]; + int schema_idx = scalar_field_indices[i]; + auto const field_meta = make_field_meta_view(context, schema_idx); + auto const dt = field_meta.output_type; if (dt.id() != cudf::type_id::STRING && dt.id() != cudf::type_id::LIST) continue; - auto const enc = schema[schema_idx].encoding; - bool has_def = schema[schema_idx].has_default_value; + auto const enc = field_meta.schema.encoding; + bool has_def = field_meta.schema.has_default_value; switch (dt.id()) { case cudf::type_id::STRING: { @@ -507,7 +510,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // 3. Convert INT32 -> UTF-8 enum name bytes. rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); - int64_t def_int = has_def ? default_ints[schema_idx] : 0; + int64_t def_int = has_def ? 
field_meta.default_int : 0; TopLevelLocationProvider loc_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; extract_varint_kernel @@ -553,7 +556,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } else { // Regular protobuf STRING (length-delimited) bool has_def_str = has_def; - auto const& def_str = default_strings[schema_idx]; + auto const& def_str = field_meta.default_string; TopLevelLocationProvider len_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; TopLevelLocationProvider copy_provider{ @@ -579,7 +582,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case cudf::type_id::LIST: { // bytes (BinaryType) represented as LIST bool has_def_bytes = has_def; - auto const& def_bytes = default_strings[schema_idx]; + auto const& def_bytes = field_meta.default_string; TopLevelLocationProvider len_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; TopLevelLocationProvider copy_provider{ @@ -832,12 +835,11 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& mr); break; case cudf::type_id::STRING: { - auto enc = schema[schema_idx].encoding; + auto const field_meta = make_field_meta_view(context, schema_idx); + auto enc = field_meta.schema.encoding; if (enc == spark_rapids_jni::ENC_ENUM_STRING) { - if (schema_idx < static_cast(enum_valid_values.size()) && - schema_idx < static_cast(enum_names.size()) && - !enum_valid_values[schema_idx].empty() && - enum_valid_values[schema_idx].size() == enum_names[schema_idx].size()) { + if (!field_meta.enum_valid_values.empty() && + field_meta.enum_valid_values.size() == field_meta.enum_names.size()) { column_map[schema_idx] = build_repeated_enum_string_column(binary_input, message_data, @@ -847,8 +849,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_occurrences, total_count, num_rows, - enum_valid_values[schema_idx], - enum_names[schema_idx], + field_meta.enum_valid_values, + 
field_meta.enum_names, d_row_has_invalid_enum, d_error, stream, diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index a42ec6e80a..5623f0e624 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -23,6 +23,8 @@ #include #include +#include +#include #include namespace spark_rapids_jni { @@ -67,6 +69,62 @@ struct ProtobufDecodeContext { bool fail_on_errors; }; +struct ProtobufFieldMetaView { + nested_field_descriptor const& schema; + cudf::data_type const& output_type; + int64_t default_int; + double default_float; + bool default_bool; + std::vector const& default_string; + std::vector const& enum_valid_values; + std::vector> const& enum_names; +}; + +inline void validate_decode_context(ProtobufDecodeContext const& context) +{ + auto const num_fields = context.schema.size(); + auto const fail_size = [&](char const* name, size_t actual) { + throw std::invalid_argument(std::string("protobuf decode context: ") + name + + " size mismatch with schema (" + std::to_string(actual) + " vs " + + std::to_string(num_fields) + ")"); + }; + + if (context.schema_output_types.size() != num_fields) fail_size("schema_output_types", + context.schema_output_types.size()); + if (context.default_ints.size() != num_fields) fail_size("default_ints", context.default_ints.size()); + if (context.default_floats.size() != num_fields) + fail_size("default_floats", context.default_floats.size()); + if (context.default_bools.size() != num_fields) + fail_size("default_bools", context.default_bools.size()); + if (context.default_strings.size() != num_fields) + fail_size("default_strings", context.default_strings.size()); + if (context.enum_valid_values.size() != num_fields) + fail_size("enum_valid_values", context.enum_valid_values.size()); + if (context.enum_names.size() != num_fields) fail_size("enum_names", context.enum_names.size()); + + for (size_t i = 0; i < num_fields; ++i) { + auto const& field = context.schema[i]; + if (field.encoding == 
ENC_ENUM_STRING && + context.enum_valid_values[i].size() != context.enum_names[i].size()) { + throw std::invalid_argument("protobuf decode context: enum-as-string metadata mismatch at field " + + std::to_string(i)); + } + } +} + +inline ProtobufFieldMetaView make_field_meta_view(ProtobufDecodeContext const& context, int schema_idx) +{ + auto const idx = static_cast(schema_idx); + return ProtobufFieldMetaView{context.schema.at(idx), + context.schema_output_types.at(idx), + context.default_ints.at(idx), + context.default_floats.at(idx), + context.default_bools.at(idx), + context.default_strings.at(idx), + context.enum_valid_values.at(idx), + context.enum_names.at(idx)}; +} + /** * Decode protobuf messages (one message per row) from a LIST column into a STRUCT * column, with support for nested messages and repeated fields. diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index 0904d0fcb7..fbaa44c86b 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -194,6 +194,16 @@ private static void validate( throw new IllegalArgumentException( "Invalid encoding at index " + i + ": " + enc); } + if (isRepeated[i] && hasDefaultValue[i]) { + throw new IllegalArgumentException( + "Repeated field at index " + i + " cannot carry a default value"); + } + if (enc == Protobuf.ENC_ENUM_STRING && + (enumValidValues[i] == null || enumNames[i] == null)) { + throw new IllegalArgumentException( + "Enum-as-string field at index " + i + + " must provide both enumValidValues and enumNames"); + } if (enumValidValues[i] != null) { int[] ev = enumValidValues[i]; for (int j = 1; j < ev.length; j++) { diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java new file mode 
100644 index 0000000000..f83f432d93 --- /dev/null +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.nvidia.spark.rapids.jni; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ProtobufSchemaDescriptorTest { + private ProtobufSchemaDescriptor makeDescriptor( + boolean isRepeated, + boolean hasDefaultValue, + int encoding, + int[] enumValidValues, + byte[][] enumNames) { + return new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_VARINT}, + new int[]{ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{encoding}, + new boolean[]{isRepeated}, + new boolean[]{false}, + new boolean[]{hasDefaultValue}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{enumValidValues}, + new byte[][][]{enumNames}); + } + + @Test + void testRepeatedFieldCannotCarryDefaultValue() { + assertThrows(IllegalArgumentException.class, () -> + makeDescriptor(true, true, Protobuf.ENC_DEFAULT, null, null)); + } + + @Test + void testEnumStringRequiresEnumMetadata() { + assertThrows(IllegalArgumentException.class, () -> + makeDescriptor(false, false, Protobuf.ENC_ENUM_STRING, null, null)); + assertThrows(IllegalArgumentException.class, () -> + 
makeDescriptor(false, false, Protobuf.ENC_ENUM_STRING, new int[]{0, 1}, null)); + assertThrows(IllegalArgumentException.class, () -> + makeDescriptor(false, false, Protobuf.ENC_ENUM_STRING, null, + new byte[][]{"A".getBytes(), "B".getBytes()})); + } +} From 66daed1210a7392dda012e3564eb311cd6a548ac Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 12 Mar 2026 10:03:49 +0800 Subject: [PATCH 068/107] reflection refactor Signed-off-by: Haoyang Li --- src/main/cpp/benchmarks/protobuf_decode.cu | 93 ++++++++++++---------- src/main/cpp/src/protobuf.hpp | 14 ++-- 2 files changed, 59 insertions(+), 48 deletions(-) diff --git a/src/main/cpp/benchmarks/protobuf_decode.cu b/src/main/cpp/benchmarks/protobuf_decode.cu index f2470ac097..334aac1bdf 100644 --- a/src/main/cpp/benchmarks/protobuf_decode.cu +++ b/src/main/cpp/benchmarks/protobuf_decode.cu @@ -158,8 +158,8 @@ std::unique_ptr make_binary_column(std::vector& buf, @@ -653,8 +653,8 @@ struct RepeatedChildListCase { for (int item_idx = 0; item_idx < num_items; item_idx++) { encode_nested_message(buf, 2, [&](std::vector& inner) { for (int child_idx = 0; child_idx < num_repeated_children; child_idx++) { - int fn = child_idx + 1; - bool is_str = child_is_string(child_idx); + int fn = child_idx + 1; + bool is_str = child_is_string(child_idx); int num_elems = vary(avg_child_elems); if (is_str) { for (int j = 0; j < num_elems; j++) { @@ -714,7 +714,7 @@ struct RepeatedChildStringOnlyCase { for (int row = 0; row < num_rows; row++) { auto& buf = out.messages[row]; for (int child_idx = 0; child_idx < num_repeated_children; child_idx++) { - int fn = child_idx + 1; + int fn = child_idx + 1; int num_elems = vary(avg_child_elems); out.counts_by_child[child_idx].push_back(num_elems); for (int j = 0; j < num_elems; j++) { @@ -1018,9 +1018,8 @@ static void BM_protobuf_repeated_child_string_count_scan(nvbench::state& state) auto binary_col = make_binary_column(data.messages); cudf::lists_column_view in_list(binary_col->view()); - auto 
const* row_offsets = in_list.offsets().data(); - auto const* message_data = - reinterpret_cast(in_list.child().data()); + auto const* row_offsets = in_list.offsets().data(); + auto const* message_data = reinterpret_cast(in_list.child().data()); auto const message_data_size = static_cast(in_list.child().size()); auto stream = cudf::get_default_stream(); @@ -1076,19 +1075,20 @@ static void BM_protobuf_repeated_child_string_count_scan(nvbench::state& state) rmm::device_uvector d_rep_info( static_cast(num_rows) * num_repeated_children, stream, mr); - spark_rapids_jni::protobuf_detail::count_repeated_in_nested_kernel<<< - (num_rows + 255) / 256, 256, 0, stream.value()>>>(message_data, - message_data_size, - row_offsets, - 0, - d_parent_locs.data(), - num_rows, - d_schema.data(), - num_repeated_children, - d_rep_info.data(), - num_repeated_children, - d_rep_indices.data(), - d_error.data()); + spark_rapids_jni::protobuf_detail:: + count_repeated_in_nested_kernel<<<(num_rows + 255) / 256, 256, 0, stream.value()>>>( + message_data, + message_data_size, + row_offsets, + 0, + d_parent_locs.data(), + num_rows, + d_schema.data(), + num_repeated_children, + d_rep_info.data(), + num_repeated_children, + d_rep_indices.data(), + d_error.data()); struct rep_work { rmm::device_uvector counts; @@ -1125,21 +1125,21 @@ static void BM_protobuf_repeated_child_string_count_scan(nvbench::state& state) for (int ri = 0; ri < num_repeated_children; ri++) { auto& w = *work[ri]; if (w.total_count > 0) { - w.occs = std::make_unique>(w.total_count, - stream, - mr); - spark_rapids_jni::protobuf_detail::scan_repeated_in_nested_kernel<<< - (num_rows + 255) / 256, 256, 0, stream.value()>>>(message_data, - message_data_size, - row_offsets, - 0, - d_parent_locs.data(), - num_rows, - d_schema.data(), - w.offsets.data(), - d_rep_indices.data() + ri, - w.occs->data(), - d_error.data()); + w.occs = + std::make_unique>(w.total_count, stream, mr); + spark_rapids_jni::protobuf_detail:: + 
scan_repeated_in_nested_kernel<<<(num_rows + 255) / 256, 256, 0, stream.value()>>>( + message_data, + message_data_size, + row_offsets, + 0, + d_parent_locs.data(), + num_rows, + d_schema.data(), + w.offsets.data(), + d_rep_indices.data() + ri, + w.occs->data(), + d_error.data()); } } }); @@ -1169,9 +1169,8 @@ static void BM_protobuf_repeated_child_string_build(nvbench::state& state) auto binary_col = make_binary_column(data.messages); cudf::lists_column_view in_list(binary_col->view()); - auto const* row_offsets = in_list.offsets().data(); - auto const* message_data = - reinterpret_cast(in_list.child().data()); + auto const* row_offsets = in_list.offsets().data(); + auto const* message_data = reinterpret_cast(in_list.child().data()); auto stream = cudf::get_default_stream(); auto mr = cudf::get_current_device_resource_ref(); @@ -1240,14 +1239,24 @@ static void BM_protobuf_repeated_child_string_build(nvbench::state& state) row_offsets, 0, d_parent_locs.data(), c.occs.data()}; auto valid_fn = [] __device__(cudf::size_type) { return true; }; std::vector empty_default; - auto child_values = spark_rapids_jni::protobuf_detail::extract_and_build_string_or_bytes_column( - false, message_data, c.total_count, nr_loc, nr_loc, valid_fn, false, empty_default, d_error, stream, mr); + auto child_values = + spark_rapids_jni::protobuf_detail::extract_and_build_string_or_bytes_column(false, + message_data, + c.total_count, + nr_loc, + nr_loc, + valid_fn, + false, + empty_default, + d_error, + stream, + mr); auto list_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, num_rows + 1, list_offs.release(), rmm::device_buffer{}, 0); - auto result = cudf::make_lists_column( + auto result = cudf::make_lists_column( num_rows, std::move(list_offs_col), std::move(child_values), 0, rmm::device_buffer{}); } }); diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index 5623f0e624..c5c53d019f 100644 --- a/src/main/cpp/src/protobuf.hpp +++ 
b/src/main/cpp/src/protobuf.hpp @@ -89,9 +89,10 @@ inline void validate_decode_context(ProtobufDecodeContext const& context) std::to_string(num_fields) + ")"); }; - if (context.schema_output_types.size() != num_fields) fail_size("schema_output_types", - context.schema_output_types.size()); - if (context.default_ints.size() != num_fields) fail_size("default_ints", context.default_ints.size()); + if (context.schema_output_types.size() != num_fields) + fail_size("schema_output_types", context.schema_output_types.size()); + if (context.default_ints.size() != num_fields) + fail_size("default_ints", context.default_ints.size()); if (context.default_floats.size() != num_fields) fail_size("default_floats", context.default_floats.size()); if (context.default_bools.size() != num_fields) @@ -106,13 +107,14 @@ inline void validate_decode_context(ProtobufDecodeContext const& context) auto const& field = context.schema[i]; if (field.encoding == ENC_ENUM_STRING && context.enum_valid_values[i].size() != context.enum_names[i].size()) { - throw std::invalid_argument("protobuf decode context: enum-as-string metadata mismatch at field " + - std::to_string(i)); + throw std::invalid_argument( + "protobuf decode context: enum-as-string metadata mismatch at field " + std::to_string(i)); } } } -inline ProtobufFieldMetaView make_field_meta_view(ProtobufDecodeContext const& context, int schema_idx) +inline ProtobufFieldMetaView make_field_meta_view(ProtobufDecodeContext const& context, + int schema_idx) { auto const idx = static_cast(schema_idx); return ProtobufFieldMetaView{context.schema.at(idx), From cd1763b38fdff908ef01703596f369953675b0b5 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 12 Mar 2026 11:23:35 +0800 Subject: [PATCH 069/107] bug fixes Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 34 +- src/main/cpp/src/protobuf_builders.cu | 354 ++++++++++++++++-- src/main/cpp/src/protobuf_common.cuh | 32 ++ .../nvidia/spark/rapids/jni/ProtobufTest.java | 244 ++++++++++++ 
4 files changed, 614 insertions(+), 50 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 5c6e947044..ecff747532 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -271,31 +271,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_locations.data(), d_error.data()); - // Check required fields (after scan pass) - { - bool has_required = false; - for (int i = 0; i < num_scalar; i++) { - int si = scalar_field_indices[i]; - if (schema[si].is_required) { - has_required = true; - break; - } - } - if (has_required) { - rmm::device_uvector d_is_required(num_scalar, stream, mr); - std::vector h_is_required(num_scalar); - for (int i = 0; i < num_scalar; i++) { - h_is_required[i] = schema[scalar_field_indices[i]].is_required ? 1 : 0; - } - CUDF_CUDA_TRY(cudaMemcpyAsync(d_is_required.data(), - h_is_required.data(), - num_scalar * sizeof(uint8_t), - cudaMemcpyHostToDevice, - stream.value())); - check_required_fields_kernel<<>>( - d_locations.data(), d_is_required.data(), num_scalar, num_rows, d_error.data()); - } - } + // Required-field validation applies to all scalar leaves, not just top-level numerics. + maybe_check_required_fields( + d_locations.data(), scalar_field_indices, schema, num_rows, d_error.data(), stream, mr); // Batched scalar extraction: group non-special fixed-width fields by extraction // category and extract all fields of each category with a single 2D kernel launch. @@ -612,6 +590,11 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } } + // Required top-level nested messages are tracked in d_nested_locations during the scan/count + // pass. + maybe_check_required_fields( + d_nested_locations.data(), nested_field_indices, schema, num_rows, d_error.data(), stream, mr); + // Process repeated fields (three-phase: offsets → combined scan → build columns) if (num_repeated > 0) { // Phase A: Compute per-row offsets for each repeated field. 
@@ -1013,6 +996,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& num_rows, stream, mr, + nullptr, 0); } } diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index f444f3941e..ecbd27d86b 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -335,6 +335,133 @@ std::unique_ptr build_enum_string_column( num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); } +inline std::unique_ptr build_repeated_msg_child_enum_string_column( + uint8_t const* message_data, + rmm::device_uvector const& d_msg_row_offsets, + rmm::device_uvector const& d_msg_locs, + rmm::device_uvector const& d_child_locs, + int child_idx, + int num_child_fields, + int total_count, + int32_t const* top_row_indices, + std::vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((total_count + threads - 1u) / threads); + + rmm::device_uvector enum_values(total_count, stream, mr); + rmm::device_uvector valid((total_count > 0 ? 
total_count : 1), stream, mr); + RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + child_idx, + num_child_fields}; + extract_varint_kernel + <<>>( + message_data, loc_provider, total_count, enum_values.data(), valid.data(), d_error.data()); + + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector d_elem_has_invalid_enum(total_count, stream, mr); + thrust::fill(rmm::exec_policy(stream), + d_elem_has_invalid_enum.begin(), + d_elem_has_invalid_enum.end(), + false); + validate_enum_values_kernel<<>>( + enum_values.data(), + valid.data(), + d_elem_has_invalid_enum.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + total_count); + + if (d_row_has_invalid_enum.size() > 0 && total_count > 0) { + auto const* elem_invalid = d_elem_has_invalid_enum.data(); + auto const* parent_rows = top_row_indices; + auto* row_invalid = d_row_has_invalid_enum.data(); + thrust::for_each(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + [elem_invalid, parent_rows, row_invalid] __device__(int idx) { + if (elem_invalid[idx]) { row_invalid[parent_rows[idx]] = true; } + }); + } + + std::vector h_name_offsets(valid_enums.size() + 1, 0); + int64_t total_name_chars = 0; + for (size_t k = 0; k < enum_name_bytes.size(); ++k) { + total_name_chars += static_cast(enum_name_bytes[k].size()); + CUDF_EXPECTS(total_name_chars <= std::numeric_limits::max(), + "Enum name data exceeds 2 GB limit"); + h_name_offsets[k + 1] = static_cast(total_name_chars); + } + std::vector h_name_chars(total_name_chars); + int32_t cursor = 0; + for (auto const& name : enum_name_bytes) { + if (!name.empty()) { + std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + 
cursor); + cursor += static_cast(name.size()); + } + } + + rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), + h_name_offsets.data(), + h_name_offsets.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + rmm::device_uvector d_name_chars(total_name_chars, stream, mr); + if (total_name_chars > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), + h_name_chars.data(), + total_name_chars * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream.value())); + } + + rmm::device_uvector lengths(total_count, stream, mr); + compute_enum_string_lengths_kernel<<>>( + enum_values.data(), + valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + static_cast(valid_enums.size()), + lengths.data(), + total_count); + + auto [offsets_col, total_chars] = + cudf::strings::detail::make_offsets_child_column(lengths.begin(), lengths.end(), stream, mr); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + copy_enum_string_chars_kernel<<>>( + enum_values.data(), + valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + d_name_chars.data(), + static_cast(valid_enums.size()), + offsets_col->view().data(), + chars.data(), + total_count); + } + + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + return cudf::make_strings_column( + total_count, std::move(offsets_col), chars.release(), null_count, std::move(mask)); +} + std::unique_ptr build_repeated_enum_string_column( cudf::column_view const& binary_input, uint8_t const* message_data, @@ -631,6 +758,7 @@ std::unique_ptr build_nested_struct_column( int num_rows, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, int depth); // Forward declaration -- build_repeated_child_list_column is defined after @@ -657,6 +785,7 @@ std::unique_ptr build_repeated_child_list_column( rmm::device_uvector& d_error, rmm::cuda_stream_view stream, 
rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, int depth); std::unique_ptr build_repeated_struct_column( @@ -788,6 +917,12 @@ std::unique_ptr build_repeated_struct_column( d_msg_row_offsets.end(), d_msg_row_offsets_size.data(), [] __device__(int32_t v) { return static_cast(v); }); + rmm::device_uvector d_top_row_indices(total_count, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_occurrences.data(), + d_occurrences.end(), + d_top_row_indices.data(), + [] __device__(repeated_occurrence const& occ) { return occ.row_idx; }); // Scan for child fields within each message occurrence rmm::device_uvector d_child_locs(total_count * num_child_fields, stream, mr); @@ -812,6 +947,10 @@ std::unique_ptr build_repeated_struct_column( h_child_lookup.empty() ? nullptr : d_child_lookup.data(), static_cast(d_child_lookup.size())); + // Enforce proto2 required semantics for fields inside each repeated message occurrence. + maybe_check_required_fields( + d_child_locs.data(), child_field_indices, schema, total_count, d_error.data(), stream, mr); + // Note: We no longer need to copy child_locs to host because: // 1. All scalar extraction kernels access d_child_locs directly on device // 2. 
String extraction uses GPU kernels @@ -848,6 +987,7 @@ std::unique_ptr build_repeated_struct_column( d_error_top, stream, mr, + d_top_row_indices.data(), 1)); continue; } @@ -889,17 +1029,45 @@ std::unique_ptr build_repeated_struct_column( break; } case cudf::type_id::STRING: { - struct_children.push_back(build_repeated_msg_child_varlen_column(message_data, - d_msg_row_offsets, - d_msg_locs, - d_child_locs, - ci, - num_child_fields, - total_count, - d_error, - false, - stream, - mr)); + if (enc == spark_rapids_jni::ENC_ENUM_STRING) { + if (child_schema_idx < static_cast(enum_valid_values.size()) && + child_schema_idx < static_cast(enum_names.size()) && + !enum_valid_values[child_schema_idx].empty() && + enum_valid_values[child_schema_idx].size() == enum_names[child_schema_idx].size()) { + struct_children.push_back( + build_repeated_msg_child_enum_string_column(message_data, + d_msg_row_offsets, + d_msg_locs, + d_child_locs, + ci, + num_child_fields, + total_count, + d_top_row_indices.data(), + enum_valid_values[child_schema_idx], + enum_names[child_schema_idx], + d_row_has_invalid_enum, + d_error, + stream, + mr)); + } else { + int err_val = ERR_MISSING_ENUM_META; + CUDF_CUDA_TRY(cudaMemcpyAsync( + d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); + struct_children.push_back(make_null_column(dt, total_count, stream, mr)); + } + } else { + struct_children.push_back(build_repeated_msg_child_varlen_column(message_data, + d_msg_row_offsets, + d_msg_locs, + d_child_locs, + ci, + num_child_fields, + total_count, + d_error, + false, + stream, + mr)); + } break; } case cudf::type_id::LIST: { @@ -977,6 +1145,7 @@ std::unique_ptr build_repeated_struct_column( total_count, stream, mr, + d_top_row_indices.data(), 0)); } break; @@ -1034,6 +1203,7 @@ std::unique_ptr build_nested_struct_column( int num_rows, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, int depth) { CUDF_EXPECTS(depth < 
MAX_NESTED_STRUCT_DECODE_DEPTH, @@ -1091,6 +1261,10 @@ std::unique_ptr build_nested_struct_column( d_child_locations.data(), d_error.data()); + // Enforce proto2 required semantics for direct children of this nested message. + maybe_check_required_fields( + d_child_locations.data(), child_field_indices, schema, num_rows, d_error.data(), stream, mr); + std::vector> struct_children; for (int ci = 0; ci < num_child_fields; ci++) { int child_schema_idx = child_field_indices[ci]; @@ -1120,6 +1294,7 @@ std::unique_ptr build_nested_struct_column( d_error, stream, mr, + top_row_indices, depth)); continue; } @@ -1319,6 +1494,7 @@ std::unique_ptr build_nested_struct_column( num_rows, stream, mr, + top_row_indices, depth + 1)); break; } @@ -1363,6 +1539,7 @@ std::unique_ptr build_repeated_child_list_column( rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, int depth) { auto const threads = THREADS_PER_BLOCK; @@ -1487,20 +1664,146 @@ std::unique_ptr build_repeated_child_list_column( stream, mr); } else if (elem_type_id == cudf::type_id::STRING || elem_type_id == cudf::type_id::LIST) { - bool as_bytes = (elem_type_id == cudf::type_id::LIST); - auto valid_fn = [] __device__(cudf::size_type) { return true; }; - std::vector empty_default; - child_values = extract_and_build_string_or_bytes_column(as_bytes, - message_data, - total_rep_count, - nr_loc, - nr_loc, - valid_fn, - false, - empty_default, - d_error, - stream, - mr); + if (elem_type_id == cudf::type_id::STRING && + schema[child_schema_idx].encoding == spark_rapids_jni::ENC_ENUM_STRING) { + if (child_schema_idx < static_cast(enum_valid_values.size()) && + child_schema_idx < static_cast(enum_names.size()) && + !enum_valid_values[child_schema_idx].empty() && + enum_valid_values[child_schema_idx].size() == enum_names[child_schema_idx].size()) { + rmm::device_uvector enum_values(total_rep_count, stream, mr); + rmm::device_uvector 
valid((total_rep_count > 0 ? total_rep_count : 1), stream, mr); + extract_varint_kernel + <<>>(message_data, + nr_loc, + total_rep_count, + enum_values.data(), + valid.data(), + d_error.data()); + + auto const& valid_enums = enum_valid_values[child_schema_idx]; + auto const& enum_name_bytes = enum_names[child_schema_idx]; + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector d_elem_has_invalid_enum(total_rep_count, stream, mr); + thrust::fill(rmm::exec_policy(stream), + d_elem_has_invalid_enum.begin(), + d_elem_has_invalid_enum.end(), + false); + validate_enum_values_kernel<<>>( + enum_values.data(), + valid.data(), + d_elem_has_invalid_enum.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + total_rep_count); + + if (d_row_has_invalid_enum.size() > 0) { + auto const* rep_occs = d_rep_occs.data(); + auto const* parent_rows = top_row_indices; + auto const* elem_invalid = d_elem_has_invalid_enum.data(); + auto* row_invalid = d_row_has_invalid_enum.data(); + thrust::for_each(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_rep_count), + [rep_occs, parent_rows, elem_invalid, row_invalid] __device__(int idx) { + if (elem_invalid[idx]) { + auto const parent_row = rep_occs[idx].row_idx; + auto const top_row = + parent_rows != nullptr ? 
parent_rows[parent_row] : parent_row; + row_invalid[top_row] = true; + } + }); + } + + std::vector h_name_offsets(valid_enums.size() + 1, 0); + int64_t total_name_chars = 0; + for (size_t k = 0; k < enum_name_bytes.size(); ++k) { + total_name_chars += static_cast(enum_name_bytes[k].size()); + CUDF_EXPECTS(total_name_chars <= std::numeric_limits::max(), + "Enum name data exceeds 2 GB limit"); + h_name_offsets[k + 1] = static_cast(total_name_chars); + } + std::vector h_name_chars(total_name_chars); + int32_t cursor = 0; + for (auto const& name : enum_name_bytes) { + if (!name.empty()) { + std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + cursor); + cursor += static_cast(name.size()); + } + } + + rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), + h_name_offsets.data(), + h_name_offsets.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + rmm::device_uvector d_name_chars(total_name_chars, stream, mr); + if (total_name_chars > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), + h_name_chars.data(), + total_name_chars * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream.value())); + } + + rmm::device_uvector lengths(total_rep_count, stream, mr); + compute_enum_string_lengths_kernel<<>>( + enum_values.data(), + valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + static_cast(valid_enums.size()), + lengths.data(), + total_rep_count); + + auto [offsets_col, total_chars] = cudf::strings::detail::make_offsets_child_column( + lengths.begin(), lengths.end(), stream, mr); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + copy_enum_string_chars_kernel<<>>( + enum_values.data(), + valid.data(), + d_valid_enums.data(), + d_name_offsets.data(), + d_name_chars.data(), + static_cast(valid_enums.size()), + offsets_col->view().data(), + chars.data(), + total_rep_count); + } + + auto [mask, null_count] = 
make_null_mask_from_valid(valid, stream, mr); + child_values = cudf::make_strings_column( + total_rep_count, std::move(offsets_col), chars.release(), null_count, std::move(mask)); + } else { + int err_val = ERR_MISSING_ENUM_META; + CUDF_CUDA_TRY(cudaMemcpyAsync( + d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); + child_values = make_null_column(cudf::data_type{elem_type_id}, total_rep_count, stream, mr); + } + } else { + bool as_bytes = (elem_type_id == cudf::type_id::LIST); + auto valid_fn = [] __device__(cudf::size_type) { return true; }; + std::vector empty_default; + child_values = extract_and_build_string_or_bytes_column(as_bytes, + message_data, + total_rep_count, + nr_loc, + nr_loc, + valid_fn, + false, + empty_default, + d_error, + stream, + mr); + } } else if (elem_type_id == cudf::type_id::STRUCT) { auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx); if (gc_indices.empty()) { @@ -1546,6 +1849,7 @@ std::unique_ptr build_repeated_child_list_column( total_rep_count, stream, mr, + top_row_indices, depth + 1); } } else { diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 47577c8bc5..7ca705d009 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -1144,6 +1144,36 @@ __global__ void check_required_fields_kernel(field_location const* locations, int num_rows, int* error_flag); +inline void maybe_check_required_fields(field_location const* locations, + std::vector const& field_indices, + std::vector const& schema, + int num_rows, + int* error_flag, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (num_rows == 0 || field_indices.empty()) { return; } + + bool has_required = false; + std::vector h_is_required(field_indices.size()); + for (size_t i = 0; i < field_indices.size(); ++i) { + h_is_required[i] = schema[field_indices[i]].is_required ? 
1 : 0; + has_required |= (h_is_required[i] != 0); + } + if (!has_required) { return; } + + rmm::device_uvector d_is_required(field_indices.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_is_required.data(), + h_is_required.data(), + h_is_required.size() * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream.value())); + + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + check_required_fields_kernel<<>>( + locations, d_is_required.data(), static_cast(field_indices.size()), num_rows, error_flag); +} + __global__ void validate_enum_values_kernel(int32_t const* values, bool* valid, bool* row_has_invalid_enum, @@ -1247,6 +1277,7 @@ std::unique_ptr build_nested_struct_column( int num_rows, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, int depth); std::unique_ptr build_repeated_child_list_column( @@ -1270,6 +1301,7 @@ std::unique_ptr build_repeated_child_list_column( rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, int depth); std::unique_ptr build_repeated_struct_column( diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index a1d241721a..8d901ad3e8 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -1316,6 +1316,73 @@ void testRequiredFieldWithMultipleRows() { } } + @Test + void testRequiredNestedMessageMissing_Failfast() { + // message Outer { required Inner detail = 1; } + // message Inner { optional int32 id = 1; } + // Missing top-level required nested message should fail in FAILFAST mode. 
+ Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector ignored = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, + new boolean[]{true, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + true)) { + } + }); + } + } + + @Test + void testRequiredFieldInsideNestedMessageMissing_Failfast() { + // message Outer { optional Inner detail = 1; } + // message Inner { required int32 id = 1; optional string name = 2; } + // If detail is present but nested required id is missing, FAILFAST should throw. 
+ Byte[] inner = concat( + box(tag(2, WT_LEN)), box(encodeVarint(4)), box("oops".getBytes())); + Byte[] row = concat( + box(tag(1, WT_LEN)), box(encodeVarint(inner.length)), inner); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector ignored = decodeRaw( + input.getColumn(0), + new int[]{1, 1, 2}, + new int[]{-1, 0, 0}, + new int[]{0, 1, 1}, + new int[]{WT_LEN, WT_VARINT, WT_LEN}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false}, + new boolean[]{false, true, false}, + new boolean[]{false, false, false}, + new long[]{0, 0, 0}, + new double[]{0.0, 0.0, 0.0}, + new boolean[]{false, false, false}, + new byte[][]{null, null, null}, + new int[][]{null, null, null}, + true)) { + } + }); + } + } + // ============================================================================ // Default Value Tests (API accepts parameters, CUDA fill not yet implemented) // ============================================================================ @@ -2497,6 +2564,80 @@ void testEnumMissingFieldDoesNotNullRow() { } } + @Test + void testNestedEnumInvalidNullsGrandchildFieldInPermissiveMode() { + // message WithNestedEnum { + // optional int32 id = 1; + // optional Detail detail = 2; + // optional string name = 3; + // } + // message Detail { + // enum Status { UNKNOWN = 0; OK = 1; BAD = 2; } + // optional Status status = 1; + // optional int32 count = 2; + // } + // Invalid nested enum should null the whole row, including grandchild field detail.count. 
+ Byte[] detail = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(999)), + box(tag(2, WT_VARINT)), box(encodeVarint(20))); + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(2)), + box(tag(2, WT_LEN)), box(encodeVarint(detail.length)), detail, + box(tag(3, WT_LEN)), box(encodeVarint(3)), box("bad".getBytes())); + + byte[][][] enumNames = new byte[][][] { + null, + null, + null, + new byte[][] { + "UNKNOWN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "OK".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BAD".getBytes(java.nio.charset.StandardCharsets.UTF_8) + }, + null + }; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 2, 3, 1, 2}, + new int[]{-1, -1, -1, 1, 1}, + new int[]{0, 0, 0, 1, 1}, + new int[]{WT_VARINT, WT_LEN, WT_LEN, WT_VARINT, WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.STRUCT.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_ENUM_STRING, + Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false, false, false}, + new boolean[]{false, false, false, false, false}, + new boolean[]{false, false, false, false, false}, + new long[]{0, 0, 0, 0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false, false}, + new byte[][]{null, null, null, null, null}, + new int[][]{null, null, null, new int[]{0, 1, 2}, null}, + enumNames, + false); + ColumnVector detailCol = actual.getChildColumnView(1).copyToColumnVector(); + ColumnVector countCol = detailCol.getChildColumnView(1).copyToColumnVector(); + HostColumnVector hostStruct = actual.copyToHost(); + HostColumnVector hostDetail = detailCol.copyToHost(); + HostColumnVector hostCount = countCol.copyToHost()) { + 
assertEquals(1, actual.getNullCount(), "Top-level row should be null"); + assertTrue(hostStruct.isNull(0), "Top-level struct should be null"); + assertEquals(1, detailCol.getNullCount(), "Nested struct child should be null after mask pushdown"); + assertTrue(hostDetail.isNull(0), "Nested struct child row should be null"); + assertEquals(1, countCol.getNullCount(), "Grandchild field should also be null"); + assertTrue(hostCount.isNull(0), "detail.count should be null when parent row is null"); + } + } + @Test void testEnumValidWithOtherFields() { // message Msg { Color color = 1; int32 count = 2; } @@ -2577,6 +2718,109 @@ void testRepeatedEnumAsString() { } } + @Test + void testRepeatedMessageChildEnumAsString() { + // message Item { optional Priority priority = 1; } + // message Outer { repeated Item items = 1; } + // enum Priority { UNKNOWN=0; FOO=1; BAR=2; } + Byte[] item0 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(1))); // FOO + Byte[] item1 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(2))); // BAR + Byte[] row = concat( + box(tag(1, WT_LEN)), box(encodeVarint(item0.length)), item0, + box(tag(1, WT_LEN)), box(encodeVarint(item1.length)), item1); + + byte[][][] enumNames = new byte[][][] { + null, + new byte[][] { + "UNKNOWN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "FOO".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BAR".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_ENUM_STRING}, + new boolean[]{true, false}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new 
byte[][]{null, null}, + new int[][]{null, new int[]{0, 1, 2}}, + enumNames, + false); + ColumnVector items = actual.getChildColumnView(0).copyToColumnVector(); + ColumnVector itemStructs = items.getChildColumnView(0).copyToColumnVector(); + ColumnVector priorities = itemStructs.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostPriorities = priorities.copyToHost()) { + assertEquals(2, priorities.getRowCount()); + assertEquals("FOO", hostPriorities.getJavaString(0)); + assertEquals("BAR", hostPriorities.getJavaString(1)); + } + } + + @Test + void testNestedRepeatedEnumAsString() { + // message Inner { repeated Priority priority = 1; } + // message Outer { optional Inner inner = 1; } + // enum Priority { UNKNOWN=0; FOO=1; BAR=2; } + byte[] packedPriorities = concatBytes(encodeVarint(0), encodeVarint(2), encodeVarint(1)); + Byte[] inner = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(packedPriorities.length)), + box(packedPriorities)); + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(inner.length)), + inner); + + byte[][][] enumNames = new byte[][][] { + null, + new byte[][] { + "UNKNOWN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "FOO".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BAR".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_ENUM_STRING}, + new boolean[]{false, true}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, new int[]{0, 1, 2}}, + enumNames, + false); + ColumnVector innerStruct = 
actual.getChildColumnView(0).copyToColumnVector(); + ColumnVector priorityList = innerStruct.getChildColumnView(0).copyToColumnVector(); + ColumnVector priorities = priorityList.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostPriorities = priorities.copyToHost()) { + assertEquals(3, priorities.getRowCount()); + assertEquals("UNKNOWN", hostPriorities.getJavaString(0)); + assertEquals("BAR", hostPriorities.getJavaString(1)); + assertEquals("FOO", hostPriorities.getJavaString(2)); + } + } + // ============================================================================ // Edge case and boundary tests // ============================================================================ From 021372111dec3032aee4399cd88e64ffa418d50d Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 12 Mar 2026 14:46:16 +0800 Subject: [PATCH 070/107] comment address Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_builders.cu | 34 ------------ .../rapids/jni/ProtobufSchemaDescriptor.java | 10 ++++ .../jni/ProtobufSchemaDescriptorTest.java | 54 +++++++++++++++++++ 3 files changed, 64 insertions(+), 34 deletions(-) diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index ecbd27d86b..8c7bf14d57 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -343,10 +343,8 @@ inline std::unique_ptr build_repeated_msg_child_enum_string_column int child_idx, int num_child_fields, int total_count, - int32_t const* top_row_indices, std::vector const& valid_enums, std::vector> const& enum_name_bytes, - rmm::device_uvector& d_row_has_invalid_enum, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) @@ -386,18 +384,6 @@ inline std::unique_ptr build_repeated_msg_child_enum_string_column static_cast(valid_enums.size()), total_count); - if (d_row_has_invalid_enum.size() > 0 && total_count > 0) { - auto const* elem_invalid = d_elem_has_invalid_enum.data(); - auto 
const* parent_rows = top_row_indices; - auto* row_invalid = d_row_has_invalid_enum.data(); - thrust::for_each(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_count), - [elem_invalid, parent_rows, row_invalid] __device__(int idx) { - if (elem_invalid[idx]) { row_invalid[parent_rows[idx]] = true; } - }); - } - std::vector h_name_offsets(valid_enums.size() + 1, 0); int64_t total_name_chars = 0; for (size_t k = 0; k < enum_name_bytes.size(); ++k) { @@ -1042,10 +1028,8 @@ std::unique_ptr build_repeated_struct_column( ci, num_child_fields, total_count, - d_top_row_indices.data(), enum_valid_values[child_schema_idx], enum_names[child_schema_idx], - d_row_has_invalid_enum, d_error, stream, mr)); @@ -1702,24 +1686,6 @@ std::unique_ptr build_repeated_child_list_column( static_cast(valid_enums.size()), total_rep_count); - if (d_row_has_invalid_enum.size() > 0) { - auto const* rep_occs = d_rep_occs.data(); - auto const* parent_rows = top_row_indices; - auto const* elem_invalid = d_elem_has_invalid_enum.data(); - auto* row_invalid = d_row_has_invalid_enum.data(); - thrust::for_each(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_rep_count), - [rep_occs, parent_rows, elem_invalid, row_invalid] __device__(int idx) { - if (elem_invalid[idx]) { - auto const parent_row = rep_occs[idx].row_idx; - auto const top_row = - parent_rows != nullptr ? 
parent_rows[parent_row] : parent_row; - row_invalid[top_row] = true; - } - }); - } - std::vector h_name_offsets(valid_enums.size() + 1, 0); int64_t total_name_chars = 0; for (size_t k = 0; k < enum_name_bytes.size(); ++k) { diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index fbaa44c86b..b59b98d037 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -16,6 +16,9 @@ package com.nvidia.spark.rapids.jni; +import java.util.HashSet; +import java.util.Set; + /** * Immutable descriptor for a flattened protobuf schema, grouping the parallel arrays * that describe field structure, types, defaults, and enum metadata. @@ -159,6 +162,7 @@ private static void validate( throw new IllegalArgumentException("All schema arrays must have the same length"); } + Set seenFieldNumbers = new HashSet<>(); for (int i = 0; i < n; i++) { if (fieldNumbers[i] <= 0 || fieldNumbers[i] > MAX_FIELD_NUMBER) { throw new IllegalArgumentException( @@ -183,6 +187,12 @@ private static void validate( ") must be parent depth (" + depthLevels[pi] + ") + 1"); } } + long fieldKey = (((long) pi) << 32) | (fieldNumbers[i] & 0xFFFFFFFFL); + if (!seenFieldNumbers.add(fieldKey)) { + throw new IllegalArgumentException( + "Duplicate field number " + fieldNumbers[i] + + " under parent index " + pi + " at schema index " + i); + } int wt = wireTypes[i]; if (wt != 0 && wt != 1 && wt != 2 && wt != 5) { throw new IllegalArgumentException( diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java index f83f432d93..e4f5b11de6 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java @@ 
-18,6 +18,7 @@ import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertThrows; public class ProtobufSchemaDescriptorTest { @@ -61,4 +62,57 @@ void testEnumStringRequiresEnumMetadata() { makeDescriptor(false, false, Protobuf.ENC_ENUM_STRING, null, new byte[][]{"A".getBytes(), "B".getBytes()})); } + + @Test + void testDuplicateFieldNumbersUnderSameParentRejected() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1, 7, 7}, + new int[]{-1, 0, 0}, + new int[]{0, 1, 1}, + new int[]{Protobuf.WT_LEN, Protobuf.WT_VARINT, Protobuf.WT_VARINT}, + new int[]{ + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId(), + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false}, + new boolean[]{false, false, false}, + new boolean[]{false, false, false}, + new long[]{0, 0, 0}, + new double[]{0.0, 0.0, 0.0}, + new boolean[]{false, false, false}, + new byte[][]{null, null, null}, + new int[][]{null, null, null}, + new byte[][][]{null, null, null})); + } + + @Test + void testDuplicateFieldNumbersUnderDifferentParentsAllowed() { + assertDoesNotThrow(() -> + new ProtobufSchemaDescriptor( + new int[]{1, 2, 7, 7}, + new int[]{-1, -1, 0, 1}, + new int[]{0, 0, 1, 1}, + new int[]{Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_VARINT, Protobuf.WT_VARINT}, + new int[]{ + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId(), + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{ + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false, false}, + new boolean[]{false, false, false, 
false}, + new boolean[]{false, false, false, false}, + new long[]{0, 0, 0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false}, + new byte[][]{null, null, null, null}, + new int[][]{null, null, null, null}, + new byte[][][]{null, null, null, null})); + } } From 5357378e110f70c971ba3f589578da83ed52b678 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 12 Mar 2026 15:23:52 +0800 Subject: [PATCH 071/107] comment address Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 3 +- src/main/cpp/src/protobuf_builders.cu | 75 ++++++----- src/main/cpp/src/protobuf_common.cuh | 32 +++-- src/main/cpp/src/protobuf_kernels.cu | 119 ++++++++++-------- .../com/nvidia/spark/rapids/jni/Protobuf.java | 1 - .../nvidia/spark/rapids/jni/ProtobufTest.java | 101 +++++++++++++++ 6 files changed, 223 insertions(+), 108 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index ecff747532..fc54da6b3d 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -240,6 +240,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& int schema_idx = scalar_field_indices[i]; h_field_descs[i].field_number = schema[schema_idx].field_number; h_field_descs[i].expected_wire_type = schema[schema_idx].wire_type; + h_field_descs[i].is_repeated = false; } rmm::device_uvector d_field_descs(num_scalar, stream, mr); @@ -691,7 +692,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_scan_descs.data(), static_cast(h_scan_descs.size()), d_error.data(), - d_fn_to_scan.data(), + fn_to_scan_size > 0 ? 
d_fn_to_scan.data() : nullptr, fn_to_scan_size); } diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 8c7bf14d57..d844cf2b7a 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -27,7 +27,7 @@ namespace spark_rapids_jni::protobuf_detail { */ inline std::unique_ptr build_repeated_msg_child_varlen_column( uint8_t const* message_data, - rmm::device_uvector const& d_msg_row_offsets, + rmm::device_uvector const& d_msg_row_offsets, rmm::device_uvector const& d_msg_locs, rmm::device_uvector const& d_child_locs, int child_idx, @@ -337,7 +337,7 @@ std::unique_ptr build_enum_string_column( inline std::unique_ptr build_repeated_msg_child_enum_string_column( uint8_t const* message_data, - rmm::device_uvector const& d_msg_row_offsets, + rmm::device_uvector const& d_msg_row_offsets, rmm::device_uvector const& d_msg_locs, rmm::device_uvector const& d_child_locs, int child_idx, @@ -361,8 +361,14 @@ inline std::unique_ptr build_repeated_msg_child_enum_string_column child_idx, num_child_fields}; extract_varint_kernel - <<>>( - message_data, loc_provider, total_count, enum_values.data(), valid.data(), d_error.data()); + <<>>(message_data, + loc_provider, + total_count, + enum_values.data(), + valid.data(), + d_error.data(), + false, + 0); rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), @@ -471,8 +477,15 @@ std::unique_ptr build_repeated_enum_string_column( rmm::device_uvector enum_ints(total_count, stream, mr); rmm::device_uvector elem_valid(total_count, stream, mr); RepeatedLocationProvider rep_loc{list_offsets, base_offset, d_occurrences.data()}; - extract_varint_kernel<<>>( - message_data, rep_loc, total_count, enum_ints.data(), elem_valid.data(), d_error.data()); + extract_varint_kernel + <<>>(message_data, + rep_loc, + total_count, + enum_ints.data(), + elem_valid.data(), + d_error.data(), + false, + 0); // 2. 
Build device-side enum lookup tables rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); @@ -862,6 +875,7 @@ std::unique_ptr build_repeated_struct_column( int child_schema_idx = child_field_indices[ci]; h_child_descs[ci].field_number = h_device_schema[child_schema_idx].field_number; h_child_descs[ci].expected_wire_type = h_device_schema[child_schema_idx].wire_type; + h_child_descs[ci].is_repeated = h_device_schema[child_schema_idx].is_repeated; } rmm::device_uvector d_child_descs(num_child_fields, stream, mr); CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_descs.data(), @@ -884,8 +898,7 @@ std::unique_ptr build_repeated_struct_column( // Create "virtual" parent locations from the occurrences using GPU kernel // This replaces the host-side loop with D->H->D copy pattern (critical performance fix!) rmm::device_uvector d_msg_locs(total_count, stream, mr); - rmm::device_uvector d_msg_row_offsets(total_count, stream, mr); - rmm::device_uvector d_msg_row_offsets_size(total_count, stream, mr); + rmm::device_uvector d_msg_row_offsets(total_count, stream, mr); { auto const occ_threads = THREADS_PER_BLOCK; auto const occ_blocks = (total_count + occ_threads - 1u) / occ_threads; @@ -898,11 +911,6 @@ std::unique_ptr build_repeated_struct_column( total_count, d_error_top.data()); } - thrust::transform(rmm::exec_policy(stream), - d_msg_row_offsets.data(), - d_msg_row_offsets.end(), - d_msg_row_offsets_size.data(), - [] __device__(int32_t v) { return static_cast(v); }); rmm::device_uvector d_top_row_indices(total_count, stream, mr); thrust::transform(rmm::exec_policy(stream), d_occurrences.data(), @@ -955,7 +963,7 @@ std::unique_ptr build_repeated_struct_column( if (child_is_repeated) { struct_children.push_back(build_repeated_child_list_column(message_data, message_data_size, - d_msg_row_offsets_size.data(), + d_msg_row_offsets.data(), 0, d_msg_locs.data(), total_count, @@ -1086,33 +1094,21 @@ std::unique_ptr build_repeated_struct_column( // Compute virtual parent locations 
for each occurrence's nested struct child rmm::device_uvector d_nested_locs(total_count, stream, mr); rmm::device_uvector d_nested_row_offsets(total_count, stream, mr); - { - // Convert int32_t row offsets to cudf::size_type and compute nested struct locations - rmm::device_uvector d_nested_row_offsets_i32(total_count, stream, mr); - compute_nested_struct_locations_kernel<<>>( - d_child_locs.data(), - d_msg_locs.data(), - d_msg_row_offsets.data(), - ci, - num_child_fields, - d_nested_locs.data(), - d_nested_row_offsets_i32.data(), - total_count, - d_error_top.data()); - // Add base_offset back so build_nested_struct_column can subtract it - thrust::transform(rmm::exec_policy(stream), - d_nested_row_offsets_i32.data(), - d_nested_row_offsets_i32.end(), - d_nested_row_offsets.data(), - [base_offset] __device__(int32_t v) { - return static_cast(v) + base_offset; - }); - } + compute_nested_struct_locations_kernel<<>>( + d_child_locs.data(), + d_msg_locs.data(), + d_msg_row_offsets.data(), + ci, + num_child_fields, + d_nested_locs.data(), + d_nested_row_offsets.data(), + total_count, + d_error_top.data()); struct_children.push_back(build_nested_struct_column(message_data, message_data_size, d_nested_row_offsets.data(), - base_offset, + 0, d_nested_locs, grandchild_indices, schema, @@ -1222,6 +1218,7 @@ std::unique_ptr build_nested_struct_column( int child_idx = child_field_indices[i]; h_child_field_descs[i].field_number = schema[child_idx].field_number; h_child_field_descs[i].expected_wire_type = schema[child_idx].wire_type; + h_child_field_descs[i].is_repeated = schema[child_idx].is_repeated; } rmm::device_uvector d_child_field_descs(num_child_fields, stream, mr); @@ -1662,7 +1659,9 @@ std::unique_ptr build_repeated_child_list_column( total_rep_count, enum_values.data(), valid.data(), - d_error.data()); + d_error.data(), + false, + 0); auto const& valid_enums = enum_valid_values[child_schema_idx]; auto const& enum_name_bytes = enum_names[child_schema_idx]; diff --git 
a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 7ca705d009..a9c3d96bc0 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -99,6 +99,7 @@ struct field_location { struct field_descriptor { int field_number; // Protobuf field number int expected_wire_type; // Expected wire type for this field + bool is_repeated; // Repeated children are scanned via count/scan kernels }; /** @@ -427,11 +428,12 @@ inline std::vector build_field_lookup_table(field_descriptor const* descs, * O(1) lookup of field_number -> field_index using a direct-mapped table. * Falls back to linear search when the table is empty (field numbers too large). */ -__device__ inline int lookup_field(int field_number, - int const* lookup_table, - int lookup_table_size, - field_descriptor const* field_descs, - int num_fields) +// Keep this definition in the header so all CUDA translation units can inline it. +__device__ __forceinline__ int lookup_field(int field_number, + int const* lookup_table, + int lookup_table_size, + field_descriptor const* field_descs, + int num_fields) { if (lookup_table != nullptr && field_number > 0 && field_number < lookup_table_size) { return lookup_table[field_number]; @@ -507,10 +509,14 @@ struct NestedRepeatedLocationProvider { __device__ inline field_location get(int thread_idx, int32_t& data_offset) const { - auto occ = occurrences[thread_idx]; - auto ploc = parent_locations[occ.row_idx]; - data_offset = row_offsets[occ.row_idx] - base_offset + ploc.offset + occ.offset; - return {occ.offset, occ.length}; + auto occ = occurrences[thread_idx]; + auto ploc = parent_locations[occ.row_idx]; + if (ploc.offset >= 0) { + data_offset = row_offsets[occ.row_idx] - base_offset + ploc.offset + occ.offset; + return {occ.offset, occ.length}; + } + data_offset = 0; + return {-1, 0}; } }; @@ -1061,7 +1067,7 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, __global__ void 
scan_repeated_message_children_kernel(uint8_t const* message_data, cudf::size_type message_data_size, - int32_t const* msg_row_offsets, + cudf::size_type const* msg_row_offsets, field_location const* msg_locs, int num_occurrences, field_descriptor const* child_descs, @@ -1098,11 +1104,11 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, __global__ void compute_nested_struct_locations_kernel(field_location const* child_locs, field_location const* msg_locs, - int32_t const* msg_row_offsets, + cudf::size_type const* msg_row_offsets, int child_idx, int num_child_fields, field_location* nested_locs, - int32_t* nested_row_offsets, + cudf::size_type* nested_row_offsets, int total_count, int* error_flag); @@ -1128,7 +1134,7 @@ __global__ void compute_msg_locations_from_occurrences_kernel( cudf::size_type const* list_offsets, cudf::size_type base_offset, field_location* msg_locs, - int32_t* msg_row_offsets, + cudf::size_type* msg_row_offsets, int total_count, int* error_flag); diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 8b9d8f32e8..cfc332af79 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -467,7 +467,8 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co write_idx[f], error_flag); } - return true; + set_error_once(error_flag, ERR_WIRE_TYPE); + return false; }; if (fn_to_desc_idx != nullptr && fn > 0 && fn < fn_to_desc_size) { @@ -542,6 +543,11 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, for (int f = 0; f < num_fields; f++) { if (field_descs[f].field_number == fn) { + if (field_descs[f].is_repeated) { + // Repeated children are handled by the dedicated count/scan path, not by + // the direct-child location scan used for scalar/nested singleton fields. 
+ continue; + } if (wt != field_descs[f].expected_wire_type) { set_error_once(error_flag, ERR_WIRE_TYPE); return; @@ -590,7 +596,7 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, __global__ void scan_repeated_message_children_kernel( uint8_t const* message_data, cudf::size_type message_data_size, - int32_t const* msg_row_offsets, // Row offset for each occurrence + cudf::size_type const* msg_row_offsets, // Row offset for each occurrence field_location const* msg_locs, // Location of each message occurrence (offset within row, length) int num_occurrences, @@ -612,9 +618,9 @@ __global__ void scan_repeated_message_children_kernel( auto const& msg_loc = msg_locs[occ_idx]; if (msg_loc.offset < 0) return; - int32_t row_offset = msg_row_offsets[occ_idx]; - int64_t msg_start_off = static_cast(row_offset) + msg_loc.offset; - int64_t msg_end_off = msg_start_off + msg_loc.length; + cudf::size_type row_offset = msg_row_offsets[occ_idx]; + int64_t msg_start_off = static_cast(row_offset) + msg_loc.offset; + int64_t msg_end_off = msg_start_off + msg_loc.length; if (msg_start_off < 0 || msg_end_off > message_data_size) { set_error_once(error_flag, ERR_BOUNDS); return; @@ -632,53 +638,55 @@ __global__ void scan_repeated_message_children_kernel( int f = lookup_field(fn, child_lookup, child_lookup_size, child_descs, num_child_fields); if (f >= 0) { - bool is_packed = (wt == WT_LEN && child_descs[f].expected_wire_type != WT_LEN); - if (!is_packed && wt != child_descs[f].expected_wire_type) { + if (child_descs[f].is_repeated) { + // Repeated children are decoded by build_repeated_child_list_column via the + // nested repeated count/scan kernels, so do not record a singleton location here. 
+ } else if (wt != child_descs[f].expected_wire_type) { set_error_once(error_flag, ERR_WIRE_TYPE); return; - } - - int data_offset = static_cast(cur - msg_start); - - if (wt == WT_LEN) { - uint64_t len; - int len_bytes; - if (!read_varint(cur, msg_end, len, len_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return; - } - if (len > static_cast(msg_end - cur - len_bytes) || - len > static_cast(INT_MAX)) { - set_error_once(error_flag, ERR_OVERFLOW); - return; - } - child_locs[occ_idx * num_child_fields + f] = {data_offset + len_bytes, - static_cast(len)}; } else { - // For varint/fixed types, store offset and estimated length - int32_t data_length = 0; - if (wt == WT_VARINT) { - uint64_t dummy; - int vbytes; - if (!read_varint(cur, msg_end, dummy, vbytes)) { + int data_offset = static_cast(cur - msg_start); + + if (wt == WT_LEN) { + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { set_error_once(error_flag, ERR_VARINT); return; } - data_length = vbytes; - } else if (wt == WT_32BIT) { - if (msg_end - cur < 4) { - set_error_once(error_flag, ERR_FIXED_LEN); + if (len > static_cast(msg_end - cur - len_bytes) || + len > static_cast(INT_MAX)) { + set_error_once(error_flag, ERR_OVERFLOW); return; } - data_length = 4; - } else if (wt == WT_64BIT) { - if (msg_end - cur < 8) { - set_error_once(error_flag, ERR_FIXED_LEN); - return; + child_locs[occ_idx * num_child_fields + f] = {data_offset + len_bytes, + static_cast(len)}; + } else { + // For varint/fixed types, store offset and estimated length + int32_t data_length = 0; + if (wt == WT_VARINT) { + uint64_t dummy; + int vbytes; + if (!read_varint(cur, msg_end, dummy, vbytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + data_length = vbytes; + } else if (wt == WT_32BIT) { + if (msg_end - cur < 4) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } + data_length = 4; + } else if (wt == WT_64BIT) { + if (msg_end - cur < 8) { + set_error_once(error_flag, 
ERR_FIXED_LEN); + return; + } + data_length = 8; } - data_length = 8; + child_locs[occ_idx * num_child_fields + f] = {data_offset, data_length}; } - child_locs[occ_idx * num_child_fields + f] = {data_offset, data_length}; } } @@ -846,13 +854,13 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, * This is a critical performance optimization. */ __global__ void compute_nested_struct_locations_kernel( - field_location const* child_locs, // Child field locations from parent scan - field_location const* msg_locs, // Parent message locations - int32_t const* msg_row_offsets, // Parent message row offsets - int child_idx, // Which child field is the nested struct - int num_child_fields, // Total number of child fields per occurrence - field_location* nested_locs, // Output: nested struct locations - int32_t* nested_row_offsets, // Output: nested struct row offsets + field_location const* child_locs, // Child field locations from parent scan + field_location const* msg_locs, // Parent message locations + cudf::size_type const* msg_row_offsets, // Parent message row offsets + int child_idx, // Which child field is the nested struct + int num_child_fields, // Total number of child fields per occurrence + field_location* nested_locs, // Output: nested struct locations + cudf::size_type* nested_row_offsets, // Output: nested struct row offsets int total_count, int* error_flag) { @@ -861,13 +869,14 @@ __global__ void compute_nested_struct_locations_kernel( nested_locs[idx] = child_locs[idx * num_child_fields + child_idx]; auto sum = static_cast(msg_row_offsets[idx]) + msg_locs[idx].offset; - if (sum < std::numeric_limits::min() || sum > std::numeric_limits::max()) { + if (sum < std::numeric_limits::min() || + sum > std::numeric_limits::max()) { nested_locs[idx] = {-1, 0}; nested_row_offsets[idx] = 0; set_error_once(error_flag, ERR_OVERFLOW); return; } - nested_row_offsets[idx] = static_cast(sum); + nested_row_offsets[idx] = static_cast(sum); } /** @@ 
-952,7 +961,7 @@ __global__ void compute_msg_locations_from_occurrences_kernel( cudf::size_type const* list_offsets, // List offsets for rows cudf::size_type base_offset, // Base offset to subtract field_location* msg_locs, // Output: message locations - int32_t* msg_row_offsets, // Output: message row offsets + cudf::size_type* msg_row_offsets, // Output: message row offsets int total_count, int* error_flag) { @@ -961,14 +970,14 @@ __global__ void compute_msg_locations_from_occurrences_kernel( auto const& occ = occurrences[idx]; auto row_offset = static_cast(list_offsets[occ.row_idx]) - base_offset; - if (row_offset < std::numeric_limits::min() || - row_offset > std::numeric_limits::max()) { + if (row_offset < std::numeric_limits::min() || + row_offset > std::numeric_limits::max()) { msg_row_offsets[idx] = 0; msg_locs[idx] = {-1, 0}; set_error_once(error_flag, ERR_OVERFLOW); return; } - msg_row_offsets[idx] = static_cast(row_offset); + msg_row_offsets[idx] = static_cast(row_offset); msg_locs[idx] = {occ.offset, occ.length}; } diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index 653ede0c33..1660457d79 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -60,7 +60,6 @@ public class Protobuf { public static final int WT_64BIT = 1; public static final int WT_LEN = 2; public static final int WT_32BIT = 5; - private static final int MAX_FIELD_NUMBER = (1 << 29) - 1; /** * Decode protobuf messages into a STRUCT column. 
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 8d901ad3e8..7696997189 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -2038,6 +2038,107 @@ void testPackedRepeatedInsideNestedMessage() { } } + @Test + void testPackedRepeatedChildInsideRepeatedMessage() { + // message Item { repeated int32 ids = 1 [packed=true]; optional int32 score = 2; } + // message Outer { repeated Item items = 1; } + byte[] item0Ids = concatBytes(encodeVarint(10), encodeVarint(20)); + Byte[] item0 = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(item0Ids.length)), + box(item0Ids), + box(tag(2, WT_VARINT)), + box(encodeVarint(7))); + byte[] item1Ids = concatBytes(encodeVarint(30)); + Byte[] item1 = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(item1Ids.length)), + box(item1Ids), + box(tag(2, WT_VARINT)), + box(encodeVarint(9))); + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(item0.length)), + item0, + box(tag(1, WT_LEN)), + box(encodeVarint(item1.length)), + item1); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedItems = ColumnVector.fromLists( + new ListType(true, + new StructType(true, + new ListType(true, new BasicType(true, DType.INT32)), + new BasicType(true, DType.INT32))), + Arrays.asList( + new StructData(Arrays.asList(10, 20), 7), + new StructData(Arrays.asList(30), 9))); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedItems); + ColumnVector actualStruct = decodeRaw( + input.getColumn(0), + new int[]{1, 1, 2}, + new int[]{-1, 0, 0}, + new int[]{0, 1, 1}, + new int[]{WT_LEN, WT_VARINT, WT_VARINT}, + new int[]{ + DType.STRUCT.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId() + }, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, 
Protobuf.ENC_DEFAULT}, + new boolean[]{true, true, false}, + new boolean[]{false, false, false}, + new boolean[]{false, false, false}, + new long[]{0, 0, 0}, + new double[]{0.0, 0.0, 0.0}, + new boolean[]{false, false, false}, + new byte[][]{null, null, null}, + new int[][]{null, null, null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testPermissiveRepeatedWrongWireTypeDoesNotCorruptFollowingRow() { + // message Msg { repeated int32 ids = 1; } + // Row 0 has one valid element, then a malformed fixed32 occurrence for the same field, + // then another valid varint that must be ignored once the row is marked malformed. + // Row 1 must keep its own slot and not be overwritten by row 0's trailing occurrence. + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), + box(tag(1, WT_32BIT)), box(encodeFixed32(77)), + box(tag(1, WT_VARINT)), box(encodeVarint(2))); + Byte[] row1 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(100))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row0, row1}).build(); + ColumnVector expectedIds = ColumnVector.fromLists( + new ListType(true, new BasicType(true, DType.INT32)), + Arrays.asList(1), + Arrays.asList(100)); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedIds); + ColumnVector actualStruct = decodeRaw( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + @Test void testRepeatedUint32() { Byte[] row = concat( From c4b1507b9adf85f9f5bcefa4769bd82e4685ec23 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 12 Mar 2026 
17:10:07 +0800 Subject: [PATCH 072/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_builders.cu | 384 ++++++------------ src/main/cpp/src/protobuf_common.cuh | 107 ++++- .../nvidia/spark/rapids/jni/ProtobufTest.java | 39 ++ 3 files changed, 250 insertions(+), 280 deletions(-) diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index d844cf2b7a..f73f3ce5bb 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -243,19 +243,18 @@ std::unique_ptr make_empty_list_column(std::unique_ptr build_enum_string_column( - rmm::device_uvector& enum_values, - rmm::device_uvector& valid, +struct enum_string_lookup_tables { + rmm::device_uvector d_valid_enums; + rmm::device_uvector d_name_offsets; + rmm::device_uvector d_name_chars; +}; + +inline enum_string_lookup_tables make_enum_string_lookup_tables( std::vector const& valid_enums, std::vector> const& enum_name_bytes, - rmm::device_uvector& d_row_has_invalid_enum, - int num_rows, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((num_rows + threads - 1u) / threads); - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), valid_enums.data(), @@ -263,14 +262,6 @@ std::unique_ptr build_enum_string_column( cudaMemcpyHostToDevice, stream.value())); - validate_enum_values_kernel<<>>( - enum_values.data(), - valid.data(), - d_row_has_invalid_enum.data(), - d_valid_enums.data(), - static_cast(valid_enums.size()), - num_rows); - std::vector h_name_offsets(valid_enums.size() + 1, 0); int64_t total_name_chars = 0; for (size_t k = 0; k < enum_name_bytes.size(); ++k) { @@ -279,6 +270,7 @@ std::unique_ptr build_enum_string_column( "Enum name data exceeds 2 GB limit"); h_name_offsets[k + 1] = static_cast(total_name_chars); } + std::vector h_name_chars(total_name_chars); int32_t 
cursor = 0; for (auto const& name : enum_name_bytes) { @@ -294,6 +286,7 @@ std::unique_ptr build_enum_string_column( h_name_offsets.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + rmm::device_uvector d_name_chars(total_name_chars, stream, mr); if (total_name_chars > 0) { CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), @@ -303,13 +296,27 @@ std::unique_ptr build_enum_string_column( stream.value())); } + return {std::move(d_valid_enums), std::move(d_name_offsets), std::move(d_name_chars)}; +} + +inline std::unique_ptr build_enum_string_values_column( + rmm::device_uvector& enum_values, + rmm::device_uvector& valid, + enum_string_lookup_tables const& lookup, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((num_rows + threads - 1u) / threads); + rmm::device_uvector lengths(num_rows, stream, mr); compute_enum_string_lengths_kernel<<>>( enum_values.data(), valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - static_cast(valid_enums.size()), + lookup.d_valid_enums.data(), + lookup.d_name_offsets.data(), + static_cast(lookup.d_valid_enums.size()), lengths.data(), num_rows); @@ -321,10 +328,10 @@ std::unique_ptr build_enum_string_column( copy_enum_string_chars_kernel<<>>( enum_values.data(), valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - d_name_chars.data(), - static_cast(valid_enums.size()), + lookup.d_valid_enums.data(), + lookup.d_name_offsets.data(), + lookup.d_name_chars.data(), + static_cast(lookup.d_valid_enums.size()), offsets_col->view().data(), chars.data(), num_rows); @@ -335,6 +342,38 @@ std::unique_ptr build_enum_string_column( num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); } +std::unique_ptr build_enum_string_column( + rmm::device_uvector& enum_values, + rmm::device_uvector& valid, + std::vector const& valid_enums, + std::vector> const& enum_name_bytes, + 
rmm::device_uvector& d_row_has_invalid_enum, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices) +{ + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((num_rows + threads - 1u) / threads); + auto lookup = make_enum_string_lookup_tables(valid_enums, enum_name_bytes, stream, mr); + rmm::device_uvector d_item_has_invalid_enum(num_rows, stream, mr); + thrust::fill(rmm::exec_policy(stream), + d_item_has_invalid_enum.begin(), + d_item_has_invalid_enum.end(), + false); + + validate_enum_values_kernel<<>>( + enum_values.data(), + valid.data(), + d_item_has_invalid_enum.data(), + lookup.d_valid_enums.data(), + static_cast(valid_enums.size()), + num_rows); + propagate_invalid_enum_flags_to_rows( + d_item_has_invalid_enum, d_row_has_invalid_enum, num_rows, top_row_indices, stream, mr); + return build_enum_string_values_column(enum_values, valid, lookup, num_rows, stream, mr); +} + inline std::unique_ptr build_repeated_msg_child_enum_string_column( uint8_t const* message_data, rmm::device_uvector const& d_msg_row_offsets, @@ -345,12 +384,15 @@ inline std::unique_ptr build_repeated_msg_child_enum_string_column int total_count, std::vector const& valid_enums, std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_has_invalid_enum, + int32_t const* top_row_indices, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { auto const threads = THREADS_PER_BLOCK; auto const blocks = static_cast((total_count + threads - 1u) / threads); + auto lookup = make_enum_string_lookup_tables(valid_enums, enum_name_bytes, stream, mr); rmm::device_uvector enum_values(total_count, stream, mr); rmm::device_uvector valid((total_count > 0 ? 
total_count : 1), stream, mr); @@ -370,13 +412,6 @@ inline std::unique_ptr build_repeated_msg_child_enum_string_column false, 0); - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - rmm::device_uvector d_elem_has_invalid_enum(total_count, stream, mr); thrust::fill(rmm::exec_policy(stream), d_elem_has_invalid_enum.begin(), @@ -386,72 +421,12 @@ inline std::unique_ptr build_repeated_msg_child_enum_string_column enum_values.data(), valid.data(), d_elem_has_invalid_enum.data(), - d_valid_enums.data(), + lookup.d_valid_enums.data(), static_cast(valid_enums.size()), total_count); - - std::vector h_name_offsets(valid_enums.size() + 1, 0); - int64_t total_name_chars = 0; - for (size_t k = 0; k < enum_name_bytes.size(); ++k) { - total_name_chars += static_cast(enum_name_bytes[k].size()); - CUDF_EXPECTS(total_name_chars <= std::numeric_limits::max(), - "Enum name data exceeds 2 GB limit"); - h_name_offsets[k + 1] = static_cast(total_name_chars); - } - std::vector h_name_chars(total_name_chars); - int32_t cursor = 0; - for (auto const& name : enum_name_bytes) { - if (!name.empty()) { - std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + cursor); - cursor += static_cast(name.size()); - } - } - - rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), - h_name_offsets.data(), - h_name_offsets.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - rmm::device_uvector d_name_chars(total_name_chars, stream, mr); - if (total_name_chars > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), - h_name_chars.data(), - total_name_chars * sizeof(uint8_t), - cudaMemcpyHostToDevice, - stream.value())); - } - - rmm::device_uvector lengths(total_count, stream, mr); - 
compute_enum_string_lengths_kernel<<>>( - enum_values.data(), - valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - static_cast(valid_enums.size()), - lengths.data(), - total_count); - - auto [offsets_col, total_chars] = - cudf::strings::detail::make_offsets_child_column(lengths.begin(), lengths.end(), stream, mr); - - rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - copy_enum_string_chars_kernel<<>>( - enum_values.data(), - valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - d_name_chars.data(), - static_cast(valid_enums.size()), - offsets_col->view().data(), - chars.data(), - total_count); - } - - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - return cudf::make_strings_column( - total_count, std::move(offsets_col), chars.release(), null_count, std::move(mask)); + propagate_invalid_enum_flags_to_rows( + d_elem_has_invalid_enum, d_row_has_invalid_enum, total_count, top_row_indices, stream, mr); + return build_enum_string_values_column(enum_values, valid, lookup, total_count, stream, mr); } std::unique_ptr build_repeated_enum_string_column( @@ -472,6 +447,7 @@ std::unique_ptr build_repeated_enum_string_column( { auto const rep_blocks = static_cast((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + auto lookup = make_enum_string_lookup_tables(valid_enums, enum_name_bytes, stream, mr); // 1. Extract enum integer values from occurrences rmm::device_uvector enum_ints(total_count, stream, mr); @@ -487,46 +463,7 @@ std::unique_ptr build_repeated_enum_string_column( false, 0); - // 2. 
Build device-side enum lookup tables - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - std::vector h_name_offsets(valid_enums.size() + 1, 0); - int64_t total_name_chars = 0; - for (size_t k = 0; k < enum_name_bytes.size(); ++k) { - total_name_chars += static_cast(enum_name_bytes[k].size()); - CUDF_EXPECTS(total_name_chars <= std::numeric_limits::max(), - "Enum name data exceeds 2 GB limit"); - h_name_offsets[k + 1] = static_cast(total_name_chars); - } - std::vector h_name_chars(total_name_chars); - int32_t cursor = 0; - for (auto const& nm : enum_name_bytes) { - if (!nm.empty()) { - std::copy(nm.data(), nm.data() + nm.size(), h_name_chars.data() + cursor); - cursor += static_cast(nm.size()); - } - } - rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), - h_name_offsets.data(), - h_name_offsets.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - rmm::device_uvector d_name_chars(total_name_chars, stream, mr); - if (total_name_chars > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), - h_name_chars.data(), - total_name_chars * sizeof(uint8_t), - cudaMemcpyHostToDevice, - stream.value())); - } - - // 3. Validate enum values — mark invalid as false in elem_valid + // 2. Validate enum values — mark invalid as false in elem_valid // (elem_valid was already populated by extract_varint_kernel: true for success, false for // failure) rmm::device_uvector d_elem_has_invalid_enum(total_count, stream, mr); @@ -538,62 +475,25 @@ std::unique_ptr build_repeated_enum_string_column( enum_ints.data(), elem_valid.data(), d_elem_has_invalid_enum.data(), - d_valid_enums.data(), + lookup.d_valid_enums.data(), static_cast(valid_enums.size()), total_count); - // 3c. 
Propagate per-element invalid enum flags to per-row flags for struct null mask. - // Spark CPU nullifies the entire struct row when any repeated enum element is invalid. - if (d_row_has_invalid_enum.size() > 0 && total_count > 0) { - auto const* occs = d_occurrences.data(); - auto const* elem_invalid = d_elem_has_invalid_enum.data(); - auto* row_invalid = d_row_has_invalid_enum.data(); - thrust::for_each(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(total_count), - [occs, elem_invalid, row_invalid] __device__(int idx) { - if (elem_invalid[idx]) { - // Safe: all threads write the same value (true). On sm_70+ byte stores - // are independently addressable and do not tear neighboring bytes. - row_invalid[occs[idx].row_idx] = true; - } - }); - } - - // 4. Compute per-element string lengths - rmm::device_uvector elem_lengths(total_count, stream, mr); - compute_enum_string_lengths_kernel<<>>( - enum_ints.data(), - elem_valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - static_cast(valid_enums.size()), - elem_lengths.data(), - total_count); - - // 5. Build string offsets - auto [str_offs_col, total_chars] = cudf::strings::detail::make_offsets_child_column( - elem_lengths.begin(), elem_lengths.end(), stream, mr); - - // 6. 
Copy string chars - rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - copy_enum_string_chars_kernel<<>>( - enum_ints.data(), - elem_valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - d_name_chars.data(), - static_cast(valid_enums.size()), - str_offs_col->view().data(), - chars.data(), - total_count); - } + rmm::device_uvector d_top_row_indices(total_count, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_occurrences.begin(), + d_occurrences.end(), + d_top_row_indices.begin(), + [] __device__(repeated_occurrence const& occ) { return occ.row_idx; }); + propagate_invalid_enum_flags_to_rows(d_elem_has_invalid_enum, + d_row_has_invalid_enum, + total_count, + d_top_row_indices.data(), + stream, + mr); - // 7. Assemble strings child column with null mask from elem_valid - auto [child_mask, child_null_count] = make_null_mask_from_valid(elem_valid, stream, mr); - auto child_col = cudf::make_strings_column( - total_count, std::move(str_offs_col), chars.release(), child_null_count, std::move(child_mask)); + auto child_col = + build_enum_string_values_column(enum_ints, elem_valid, lookup, total_count, stream, mr); // 8. 
Build LIST column with list offsets from per-row counts rmm::device_uvector lo(num_rows + 1, stream, mr); @@ -1019,7 +919,8 @@ std::unique_ptr build_repeated_struct_column( d_row_has_invalid_enum, d_error, stream, - mr)); + mr, + d_top_row_indices.data())); break; } case cudf::type_id::STRING: { @@ -1038,6 +939,8 @@ std::unique_ptr build_repeated_struct_column( total_count, enum_valid_values[child_schema_idx], enum_names[child_schema_idx], + d_row_has_invalid_enum, + d_top_row_indices.data(), d_error, stream, mr)); @@ -1313,7 +1216,8 @@ std::unique_ptr build_nested_struct_column( d_row_has_invalid_enum, d_error, stream, - mr)); + mr, + top_row_indices)); break; } case cudf::type_id::STRING: { @@ -1349,7 +1253,8 @@ std::unique_ptr build_nested_struct_column( d_row_has_invalid_enum, num_rows, stream, - mr)); + mr, + top_row_indices)); } else { { int err_val = ERR_MISSING_ENUM_META; @@ -1616,6 +1521,16 @@ std::unique_ptr build_repeated_child_list_column( d_rep_occs.data(), d_error.data()); + rmm::device_uvector d_rep_top_row_indices(total_rep_count, stream, mr); + thrust::transform(rmm::exec_policy(stream), + d_rep_occs.begin(), + d_rep_occs.end(), + d_rep_top_row_indices.begin(), + [top_row_indices] __device__(repeated_occurrence const& occ) { + return top_row_indices != nullptr ? 
top_row_indices[occ.row_idx] + : occ.row_idx; + }); + std::unique_ptr child_values; auto const rep_blocks = static_cast((total_rep_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); @@ -1643,7 +1558,8 @@ std::unique_ptr build_repeated_child_list_column( d_row_has_invalid_enum, d_error, stream, - mr); + mr, + d_rep_top_row_indices.data()); } else if (elem_type_id == cudf::type_id::STRING || elem_type_id == cudf::type_id::LIST) { if (elem_type_id == cudf::type_id::STRING && schema[child_schema_idx].encoding == spark_rapids_jni::ENC_ENUM_STRING) { @@ -1651,6 +1567,8 @@ std::unique_ptr build_repeated_child_list_column( child_schema_idx < static_cast(enum_names.size()) && !enum_valid_values[child_schema_idx].empty() && enum_valid_values[child_schema_idx].size() == enum_names[child_schema_idx].size()) { + auto lookup = make_enum_string_lookup_tables( + enum_valid_values[child_schema_idx], enum_names[child_schema_idx], stream, mr); rmm::device_uvector enum_values(total_rep_count, stream, mr); rmm::device_uvector valid((total_rep_count > 0 ? 
total_rep_count : 1), stream, mr); extract_varint_kernel @@ -1663,15 +1581,6 @@ std::unique_ptr build_repeated_child_list_column( false, 0); - auto const& valid_enums = enum_valid_values[child_schema_idx]; - auto const& enum_name_bytes = enum_names[child_schema_idx]; - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - rmm::device_uvector d_elem_has_invalid_enum(total_rep_count, stream, mr); thrust::fill(rmm::exec_policy(stream), d_elem_has_invalid_enum.begin(), @@ -1681,72 +1590,17 @@ std::unique_ptr build_repeated_child_list_column( enum_values.data(), valid.data(), d_elem_has_invalid_enum.data(), - d_valid_enums.data(), - static_cast(valid_enums.size()), - total_rep_count); - - std::vector h_name_offsets(valid_enums.size() + 1, 0); - int64_t total_name_chars = 0; - for (size_t k = 0; k < enum_name_bytes.size(); ++k) { - total_name_chars += static_cast(enum_name_bytes[k].size()); - CUDF_EXPECTS(total_name_chars <= std::numeric_limits::max(), - "Enum name data exceeds 2 GB limit"); - h_name_offsets[k + 1] = static_cast(total_name_chars); - } - std::vector h_name_chars(total_name_chars); - int32_t cursor = 0; - for (auto const& name : enum_name_bytes) { - if (!name.empty()) { - std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + cursor); - cursor += static_cast(name.size()); - } - } - - rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), - h_name_offsets.data(), - h_name_offsets.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - rmm::device_uvector d_name_chars(total_name_chars, stream, mr); - if (total_name_chars > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), - h_name_chars.data(), - total_name_chars * sizeof(uint8_t), - cudaMemcpyHostToDevice, - 
stream.value())); - } - - rmm::device_uvector lengths(total_rep_count, stream, mr); - compute_enum_string_lengths_kernel<<>>( - enum_values.data(), - valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - static_cast(valid_enums.size()), - lengths.data(), + lookup.d_valid_enums.data(), + static_cast(lookup.d_valid_enums.size()), total_rep_count); - - auto [offsets_col, total_chars] = cudf::strings::detail::make_offsets_child_column( - lengths.begin(), lengths.end(), stream, mr); - - rmm::device_uvector chars(total_chars, stream, mr); - if (total_chars > 0) { - copy_enum_string_chars_kernel<<>>( - enum_values.data(), - valid.data(), - d_valid_enums.data(), - d_name_offsets.data(), - d_name_chars.data(), - static_cast(valid_enums.size()), - offsets_col->view().data(), - chars.data(), - total_rep_count); - } - - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - child_values = cudf::make_strings_column( - total_rep_count, std::move(offsets_col), chars.release(), null_count, std::move(mask)); + propagate_invalid_enum_flags_to_rows(d_elem_has_invalid_enum, + d_row_has_invalid_enum, + total_rep_count, + d_rep_top_row_indices.data(), + stream, + mr); + child_values = + build_enum_string_values_column(enum_values, valid, lookup, total_rep_count, stream, mr); } else { int err_val = ERR_MISSING_ENUM_META; CUDF_CUDA_TRY(cudaMemcpyAsync( @@ -1814,7 +1668,7 @@ std::unique_ptr build_repeated_child_list_column( total_rep_count, stream, mr, - top_row_indices, + d_rep_top_row_indices.data(), depth + 1); } } else { diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index a9c3d96bc0..da2109f180 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -39,8 +39,11 @@ #include #include #include +#include #include +#include #include +#include #include #include @@ -1205,6 +1208,83 @@ __global__ void copy_enum_string_chars_kernel(int32_t const* values, char* out_chars, int 
num_rows); +inline void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_invalid, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (num_items == 0 || row_invalid.size() == 0) { return; } + + if (top_row_indices == nullptr) { + CUDF_EXPECTS(static_cast(num_items) <= row_invalid.size(), + "enum invalid-row propagation exceeded row buffer"); + thrust::transform(rmm::exec_policy(stream), + row_invalid.begin(), + row_invalid.begin() + num_items, + item_invalid.begin(), + row_invalid.begin(), + [] __device__(bool row_is_invalid, bool item_is_invalid) { + return row_is_invalid || item_is_invalid; + }); + return; + } + + rmm::device_uvector invalid_rows(num_items, stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_items), + invalid_rows.begin(), + [item_invalid = item_invalid.data(), top_row_indices] __device__(int idx) { + return item_invalid[idx] ? 
top_row_indices[idx] : -1; + }); + + auto valid_end = + thrust::remove(rmm::exec_policy(stream), invalid_rows.begin(), invalid_rows.end(), -1); + thrust::sort(rmm::exec_policy(stream), invalid_rows.begin(), valid_end); + auto unique_end = thrust::unique(rmm::exec_policy(stream), invalid_rows.begin(), valid_end); + thrust::for_each(rmm::exec_policy(stream), + invalid_rows.begin(), + unique_end, + [row_invalid = row_invalid.data()] __device__(int32_t row_idx) { + row_invalid[row_idx] = true; + }); +} + +inline void validate_enum_and_propagate_rows(rmm::device_uvector const& values, + rmm::device_uvector& valid, + std::vector const& valid_enums, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (num_items == 0 || valid_enums.empty()) { return; } + + auto const blocks = static_cast((num_items + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector item_invalid(num_items, stream, mr); + thrust::fill(rmm::exec_policy(stream), item_invalid.begin(), item_invalid.end(), false); + validate_enum_values_kernel<<>>( + values.data(), + valid.data(), + item_invalid.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + num_items); + + propagate_invalid_enum_flags_to_rows( + item_invalid, row_invalid, num_items, top_row_indices, stream, mr); +} + // ============================================================================ // Forward declarations of builder/utility functions // ============================================================================ @@ -1228,7 +1308,8 @@ std::unique_ptr build_enum_string_column( rmm::device_uvector& d_row_has_invalid_enum, int num_rows, rmm::cuda_stream_view stream, - 
rmm::device_async_resource_ref mr); + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices = nullptr); // Complex builder forward declarations std::unique_ptr build_repeated_enum_string_column( @@ -1422,7 +1503,8 @@ inline std::unique_ptr extract_typed_column( rmm::device_uvector& d_row_has_invalid_enum, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices = nullptr) { switch (dt.id()) { case cudf::type_id::BOOL8: { @@ -1463,19 +1545,14 @@ inline std::unique_ptr extract_typed_column( if (schema_idx < static_cast(enum_valid_values.size())) { auto const& valid_enums = enum_valid_values[schema_idx]; if (!valid_enums.empty()) { - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - validate_enum_values_kernel<<>>( - out.data(), - valid.data(), - d_row_has_invalid_enum.data(), - d_valid_enums.data(), - static_cast(valid_enums.size()), - num_items); + validate_enum_and_propagate_rows(out, + valid, + valid_enums, + d_row_has_invalid_enum, + num_items, + top_row_indices, + stream, + mr); } } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 7696997189..887e03d76d 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -2643,6 +2643,45 @@ void testEnumWithOtherFields_NullsEntireRow() { } } + @Test + void testRepeatedStructEnumInvalidNullsCorrectTopLevelRow() { + // enum Color { RED=0; GREEN=1; BLUE=2; } + // message Item { Color color = 1; } + // message Msg { repeated Item items = 1; } + Byte[] item00 = concat(box(tag(1, 
WT_VARINT)), box(encodeVarint(0))); // valid + Byte[] item01 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(999))); // invalid + Byte[] row0 = concat( + box(tag(1, WT_LEN)), box(encodeVarint(item00.length)), item00, + box(tag(1, WT_LEN)), box(encodeVarint(item01.length)), item01); + Byte[] item10 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(1))); // valid + Byte[] row1 = concat( + box(tag(1, WT_LEN)), box(encodeVarint(item10.length)), item10); + + try (Table input = new Table.TestBuilder().column(row0, row1).build(); + ColumnVector actualStruct = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, false}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, new int[]{0, 1, 2}}, + false); + HostColumnVector hostStruct = actualStruct.copyToHost()) { + assertEquals(1, actualStruct.getNullCount(), "Exactly one top-level row should be null"); + assertTrue(hostStruct.isNull(0), "Row 0 should be null because one repeated child enum is invalid"); + assertFalse(hostStruct.isNull(1), "Row 1 should remain valid"); + } + } + @Test void testEnumMissingFieldDoesNotNullRow() { // Missing enum field should return null for the field, but NOT null the entire row From 3852f5349463f499454e147490672f321342479d Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 12 Mar 2026 17:54:23 +0800 Subject: [PATCH 073/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/benchmarks/protobuf_decode.cu | 19 +++- src/main/cpp/src/protobuf.hpp | 100 +++++++++++++++++- src/main/cpp/src/protobuf_builders.cu | 4 + .../rapids/jni/ProtobufSchemaDescriptor.java | 59 +++++++++++ 
.../jni/ProtobufSchemaDescriptorTest.java | 62 +++++++++++ 5 files changed, 238 insertions(+), 6 deletions(-) diff --git a/src/main/cpp/benchmarks/protobuf_decode.cu b/src/main/cpp/benchmarks/protobuf_decode.cu index 334aac1bdf..511e47d9cd 100644 --- a/src/main/cpp/benchmarks/protobuf_decode.cu +++ b/src/main/cpp/benchmarks/protobuf_decode.cu @@ -28,7 +28,9 @@ #include #include #include +#include #include +#include #include namespace { @@ -162,6 +164,14 @@ using nfd = spark_rapids_jni::nested_field_descriptor; using pb_field_location = spark_rapids_jni::protobuf_detail::field_location; using pb_repeated_occurrence = spark_rapids_jni::protobuf_detail::repeated_occurrence; +inline int32_t checked_size_to_i32(size_t value, char const* what) +{ + if (value > static_cast(std::numeric_limits::max())) { + throw std::overflow_error(std::string("benchmark protobuf size exceeds int32_t for ") + what); + } + return static_cast(value); +} + void encode_string_field_record(std::vector& buf, int field_number, std::string const& s, @@ -170,9 +180,10 @@ void encode_string_field_record(std::vector& buf, { encode_tag(buf, field_number, /*WT_LEN=*/2); encode_varint(buf, s.size()); - auto const data_offset = static_cast(buf.size()); + auto const data_offset = checked_size_to_i32(buf.size(), "string field data offset"); buf.insert(buf.end(), s.begin(), s.end()); - out_occurrences.push_back({row_idx, data_offset, static_cast(s.size())}); + out_occurrences.push_back( + {row_idx, data_offset, checked_size_to_i32(s.size(), "string field length")}); } // Case 1: Flat scalars only — many top-level scalar fields. 
@@ -1055,6 +1066,10 @@ static void BM_protobuf_repeated_child_string_count_scan(nvbench::state& state) std::vector h_rep_indices(num_repeated_children); for (int i = 0; i < num_repeated_children; i++) { + CUDF_EXPECTS(h_schema[i].is_repeated, + "count_repeated_in_nested_kernel benchmark expects repeated child fields"); + CUDF_EXPECTS(h_schema[i].depth == 0, + "count_repeated_in_nested_kernel benchmark expects pre-filtered child depth 0"); h_rep_indices[i] = i; } rmm::device_uvector d_rep_indices(num_repeated_children, stream, mr); diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index c5c53d019f..036735a8fb 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -30,10 +30,17 @@ namespace spark_rapids_jni { // Encoding constants -constexpr int ENC_DEFAULT = 0; -constexpr int ENC_FIXED = 1; -constexpr int ENC_ZIGZAG = 2; -constexpr int ENC_ENUM_STRING = 3; +constexpr int ENC_DEFAULT = 0; +constexpr int ENC_FIXED = 1; +constexpr int ENC_ZIGZAG = 2; +constexpr int ENC_ENUM_STRING = 3; +constexpr int MAX_FIELD_NUMBER = (1 << 29) - 1; + +// Wire type constants +constexpr int WT_VARINT = 0; +constexpr int WT_64BIT = 1; +constexpr int WT_LEN = 2; +constexpr int WT_32BIT = 5; // Maximum nesting depth for nested messages constexpr int MAX_NESTING_DEPTH = 10; @@ -80,6 +87,42 @@ struct ProtobufFieldMetaView { std::vector> const& enum_names; }; +inline bool is_encoding_compatible(nested_field_descriptor const& field, + cudf::data_type const& type) +{ + switch (field.encoding) { + case ENC_DEFAULT: + switch (type.id()) { + case cudf::type_id::BOOL8: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: return field.wire_type == WT_VARINT; + case cudf::type_id::FLOAT32: return field.wire_type == WT_32BIT; + case cudf::type_id::FLOAT64: return field.wire_type == WT_64BIT; + case cudf::type_id::STRING: + case cudf::type_id::LIST: + case cudf::type_id::STRUCT: return 
field.wire_type == WT_LEN; + default: return false; + } + case ENC_FIXED: + switch (type.id()) { + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::FLOAT32: return field.wire_type == WT_32BIT; + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT64: return field.wire_type == WT_64BIT; + default: return false; + } + case ENC_ZIGZAG: + return field.wire_type == WT_VARINT && + (type.id() == cudf::type_id::INT32 || type.id() == cudf::type_id::INT64); + case ENC_ENUM_STRING: return field.wire_type == WT_VARINT && type.id() == cudf::type_id::STRING; + default: return false; + } +} + inline void validate_decode_context(ProtobufDecodeContext const& context) { auto const num_fields = context.schema.size(); @@ -105,6 +148,55 @@ inline void validate_decode_context(ProtobufDecodeContext const& context) for (size_t i = 0; i < num_fields; ++i) { auto const& field = context.schema[i]; + auto const& type = context.schema_output_types[i]; + if (type.id() != field.output_type) { + throw std::invalid_argument( + "protobuf decode context: schema_output_types id mismatch at field " + std::to_string(i)); + } + if (field.field_number <= 0 || field.field_number > MAX_FIELD_NUMBER) { + throw std::invalid_argument("protobuf decode context: invalid field number at field " + + std::to_string(i)); + } + if (field.parent_idx < -1 || field.parent_idx >= static_cast(i)) { + throw std::invalid_argument("protobuf decode context: invalid parent index at field " + + std::to_string(i)); + } + if (field.parent_idx == -1) { + if (field.depth != 0) { + throw std::invalid_argument( + "protobuf decode context: top-level field must have depth 0 at field " + + std::to_string(i)); + } + } else { + auto const& parent = context.schema[field.parent_idx]; + if (field.depth != parent.depth + 1) { + throw std::invalid_argument("protobuf decode context: child depth mismatch at field " + + std::to_string(i)); + } + if 
(context.schema_output_types[field.parent_idx].id() != cudf::type_id::STRUCT) { + throw std::invalid_argument("protobuf decode context: parent must be STRUCT at field " + + std::to_string(i)); + } + } + if (!(field.wire_type == WT_VARINT || field.wire_type == WT_64BIT || + field.wire_type == WT_LEN || field.wire_type == WT_32BIT)) { + throw std::invalid_argument("protobuf decode context: invalid wire type at field " + + std::to_string(i)); + } + if (field.encoding < ENC_DEFAULT || field.encoding > ENC_ENUM_STRING) { + throw std::invalid_argument("protobuf decode context: invalid encoding at field " + + std::to_string(i)); + } + if (field.is_repeated && field.has_default_value) { + throw std::invalid_argument( + "protobuf decode context: repeated field cannot carry default value at field " + + std::to_string(i)); + } + if (!is_encoding_compatible(field, type)) { + throw std::invalid_argument( + "protobuf decode context: incompatible wire type/encoding/output type at field " + + std::to_string(i)); + } if (field.encoding == ENC_ENUM_STRING && context.enum_valid_values[i].size() != context.enum_names[i].size()) { throw std::invalid_argument( diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index f73f3ce5bb..7bb42e19d3 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -1449,6 +1449,10 @@ std::unique_ptr build_repeated_child_list_column( rep_desc.encoding = 0; rep_desc.is_required = false; rep_desc.has_default_value = false; + CUDF_EXPECTS(schema[child_schema_idx].is_repeated, + "count_repeated_in_nested_kernel launch requires repeated child schema"); + CUDF_EXPECTS(rep_desc.depth == 0, + "count_repeated_in_nested_kernel launch requires pre-filtered local depth 0"); std::vector h_rep_schema = {rep_desc}; rmm::device_uvector d_rep_schema(1, stream, mr); diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java 
b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index b59b98d037..6bedc210ba 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -33,6 +33,16 @@ public final class ProtobufSchemaDescriptor implements java.io.Serializable { private static final long serialVersionUID = 1L; private static final int MAX_FIELD_NUMBER = (1 << 29) - 1; + private static final int STRUCT_TYPE_ID = ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(); + private static final int STRING_TYPE_ID = ai.rapids.cudf.DType.STRING.getTypeId().getNativeId(); + private static final int LIST_TYPE_ID = ai.rapids.cudf.DType.LIST.getTypeId().getNativeId(); + private static final int BOOL8_TYPE_ID = ai.rapids.cudf.DType.BOOL8.getTypeId().getNativeId(); + private static final int INT32_TYPE_ID = ai.rapids.cudf.DType.INT32.getTypeId().getNativeId(); + private static final int UINT32_TYPE_ID = ai.rapids.cudf.DType.UINT32.getTypeId().getNativeId(); + private static final int INT64_TYPE_ID = ai.rapids.cudf.DType.INT64.getTypeId().getNativeId(); + private static final int UINT64_TYPE_ID = ai.rapids.cudf.DType.UINT64.getTypeId().getNativeId(); + private static final int FLOAT32_TYPE_ID = ai.rapids.cudf.DType.FLOAT32.getTypeId().getNativeId(); + private static final int FLOAT64_TYPE_ID = ai.rapids.cudf.DType.FLOAT64.getTypeId().getNativeId(); final int[] fieldNumbers; final int[] parentIndices; @@ -181,6 +191,11 @@ private static void validate( "Top-level field at index " + i + " must have depth 0, got " + depthLevels[i]); } } else { + if (outputTypeIds[pi] != STRUCT_TYPE_ID) { + throw new IllegalArgumentException( + "Parent at index " + pi + " for field " + i + " must be STRUCT, got type id " + + outputTypeIds[pi]); + } if (depthLevels[i] != depthLevels[pi] + 1) { throw new IllegalArgumentException( "Field at index " + i + " depth (" + depthLevels[i] + @@ -204,6 +219,11 @@ 
private static void validate( throw new IllegalArgumentException( "Invalid encoding at index " + i + ": " + enc); } + if (!isEncodingCompatible(wt, outputTypeIds[i], enc)) { + throw new IllegalArgumentException( + "Incompatible wire type / output type / encoding at index " + i + + ": wireType=" + wt + ", outputTypeId=" + outputTypeIds[i] + ", encoding=" + enc); + } if (isRepeated[i] && hasDefaultValue[i]) { throw new IllegalArgumentException( "Repeated field at index " + i + " cannot carry a default value"); @@ -236,4 +256,43 @@ private static void validate( } } } + + private static boolean isEncodingCompatible(int wireType, int outputTypeId, int encoding) { + switch (encoding) { + case Protobuf.ENC_DEFAULT: + if (outputTypeId == BOOL8_TYPE_ID || outputTypeId == INT32_TYPE_ID || + outputTypeId == UINT32_TYPE_ID || outputTypeId == INT64_TYPE_ID || + outputTypeId == UINT64_TYPE_ID) { + return wireType == Protobuf.WT_VARINT; + } + if (outputTypeId == FLOAT32_TYPE_ID) { + return wireType == Protobuf.WT_32BIT; + } + if (outputTypeId == FLOAT64_TYPE_ID) { + return wireType == Protobuf.WT_64BIT; + } + if (outputTypeId == STRING_TYPE_ID || outputTypeId == LIST_TYPE_ID || + outputTypeId == STRUCT_TYPE_ID) { + return wireType == Protobuf.WT_LEN; + } + return false; + case Protobuf.ENC_FIXED: + if (outputTypeId == INT32_TYPE_ID || outputTypeId == UINT32_TYPE_ID || + outputTypeId == FLOAT32_TYPE_ID) { + return wireType == Protobuf.WT_32BIT; + } + if (outputTypeId == INT64_TYPE_ID || outputTypeId == UINT64_TYPE_ID || + outputTypeId == FLOAT64_TYPE_ID) { + return wireType == Protobuf.WT_64BIT; + } + return false; + case Protobuf.ENC_ZIGZAG: + return wireType == Protobuf.WT_VARINT && + (outputTypeId == INT32_TYPE_ID || outputTypeId == INT64_TYPE_ID); + case Protobuf.ENC_ENUM_STRING: + return wireType == Protobuf.WT_VARINT && outputTypeId == STRING_TYPE_ID; + default: + return false; + } + } } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java 
b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java index e4f5b11de6..85a5daa556 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java @@ -115,4 +115,66 @@ void testDuplicateFieldNumbersUnderDifferentParentsAllowed() { new int[][]{null, null, null, null}, new byte[][][]{null, null, null, null})); } + + @Test + void testChildParentMustBeStruct() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1, 2}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{Protobuf.WT_VARINT, Protobuf.WT_VARINT}, + new int[]{ + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId(), + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + new byte[][][]{null, null})); + } + + @Test + void testEncodingCompatibilityValidation() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_32BIT}, + new int[]{ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + new byte[][][]{null})); + + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_LEN}, + new int[]{ai.rapids.cudf.DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ENUM_STRING}, + new boolean[]{false}, + new boolean[]{false}, + new 
boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{{0, 1}}, + new byte[][][]{new byte[][]{"A".getBytes(), "B".getBytes()}})); + } } From ce29c1197d561597a2163c9ec50768882849016f Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Mar 2026 09:34:28 +0800 Subject: [PATCH 074/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 12 +- src/main/cpp/src/protobuf_builders.cu | 21 ++- src/main/cpp/src/protobuf_common.cuh | 37 ++++- src/main/cpp/src/protobuf_kernels.cu | 130 +++++++++++++----- .../nvidia/spark/rapids/jni/ProtobufTest.java | 17 +++ 5 files changed, 172 insertions(+), 45 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index fc54da6b3d..e8eb0a3d39 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -542,7 +542,11 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& list_offsets, base_offset, d_locations.data(), i, num_scalar}; auto valid_fn = [locs = d_locations.data(), i, num_scalar, has_def_str] __device__( cudf::size_type row) { - return locs[row * num_scalar + i].offset >= 0 || has_def_str; + return locs[protobuf_detail::flat_index(static_cast(row), + static_cast(num_scalar), + static_cast(i))] + .offset >= 0 || + has_def_str; }; column_map[schema_idx] = extract_and_build_string_or_bytes_column(false, message_data, @@ -568,7 +572,11 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& list_offsets, base_offset, d_locations.data(), i, num_scalar}; auto valid_fn = [locs = d_locations.data(), i, num_scalar, has_def_bytes] __device__( cudf::size_type row) { - return locs[row * num_scalar + i].offset >= 0 || has_def_bytes; + return locs[protobuf_detail::flat_index(static_cast(row), + static_cast(num_scalar), + static_cast(i))] + .offset >= 0 || + has_def_bytes; }; column_map[schema_idx] = extract_and_build_string_or_bytes_column(true, message_data, diff --git 
a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 7bb42e19d3..3edf50152b 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -46,7 +46,7 @@ inline std::unique_ptr build_repeated_msg_child_varlen_column( } auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1u) / threads; + auto const blocks = static_cast((total_count + threads - 1u) / threads); rmm::device_uvector d_lengths(total_count, stream, mr); thrust::transform( @@ -55,7 +55,8 @@ inline std::unique_ptr build_repeated_msg_child_varlen_column( thrust::make_counting_iterator(total_count), d_lengths.data(), [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { - auto const& loc = child_locs[idx * ncf + ci]; + auto const& loc = child_locs[flat_index( + static_cast(idx), static_cast(ncf), static_cast(ci))]; return loc.offset >= 0 ? loc.length : 0; }); @@ -71,7 +72,10 @@ inline std::unique_ptr build_repeated_msg_child_varlen_column( thrust::make_counting_iterator(total_count), d_valid.data(), [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { - return child_locs[idx * ncf + ci].offset >= 0; + return child_locs[flat_index(static_cast(idx), + static_cast(ncf), + static_cast(ci))] + .offset >= 0; }); if (total_data > 0) { @@ -1291,7 +1295,11 @@ std::unique_ptr build_nested_struct_column( ci, num_child_fields, has_def_str] __device__(cudf::size_type row) { - return (plocs[row].offset >= 0 && flocs[row * num_child_fields + ci].offset >= 0) || + return (plocs[row].offset >= 0 && + flocs[flat_index(static_cast(row), + static_cast(num_child_fields), + static_cast(ci))] + .offset >= 0) || has_def_str; }; struct_children.push_back(extract_and_build_string_or_bytes_column(false, @@ -1329,7 +1337,10 @@ std::unique_ptr build_nested_struct_column( ci, num_child_fields, has_def_bytes] __device__(cudf::size_type row) { - return 
(plocs[row].offset >= 0 && flocs[row * num_child_fields + ci].offset >= 0) || + return (plocs[row].offset >= 0 && flocs[flat_index(static_cast(row), + static_cast(num_child_fields), + static_cast(ci))] + .offset >= 0) || has_def_bytes; }; struct_children.push_back(extract_and_build_string_or_bytes_column(true, diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index da2109f180..f5b7f161b6 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -319,6 +319,21 @@ __device__ inline bool get_field_data_location( return true; } +__device__ __host__ inline size_t flat_index(size_t row, size_t width, size_t col) +{ + return row * width + col; +} + +__device__ inline bool checked_add_int32(int32_t lhs, int32_t rhs, int32_t& out) +{ + auto const sum = static_cast(lhs) + rhs; + if (sum < std::numeric_limits::min() || sum > std::numeric_limits::max()) { + return false; + } + out = static_cast(sum); + return true; +} + __device__ inline bool check_message_bounds(int32_t start, int32_t end_pos, cudf::size_type total_size, @@ -350,7 +365,7 @@ __device__ inline bool decode_tag(uint8_t const*& cur, cur += key_bytes; uint64_t fn = key >> 3; - if (fn == 0 || fn > static_cast(INT_MAX)) { + if (fn == 0 || fn > static_cast(spark_rapids_jni::MAX_FIELD_NUMBER)) { set_error_once(error_flag, ERR_FIELD_NUMBER); return false; } @@ -464,7 +479,9 @@ struct TopLevelLocationProvider { __device__ inline field_location get(int thread_idx, int32_t& data_offset) const { - auto loc = locations[thread_idx * num_fields + field_idx]; + auto loc = locations[flat_index(static_cast(thread_idx), + static_cast(num_fields), + static_cast(field_idx))]; if (loc.offset >= 0) { data_offset = offsets[thread_idx] - base_offset + loc.offset; } return loc; } @@ -494,7 +511,9 @@ struct NestedLocationProvider { __device__ inline field_location get(int thread_idx, int32_t& data_offset) const { auto ploc = parent_locations[thread_idx]; - auto 
cloc = child_locations[thread_idx * num_fields + field_idx]; + auto cloc = child_locations[flat_index(static_cast(thread_idx), + static_cast(num_fields), + static_cast(field_idx))]; if (ploc.offset >= 0 && cloc.offset >= 0) { data_offset = row_offsets[thread_idx] - base_offset + ploc.offset + cloc.offset; } else { @@ -534,7 +553,9 @@ struct RepeatedMsgChildLocationProvider { __device__ inline field_location get(int thread_idx, int32_t& data_offset) const { auto mloc = msg_locations[thread_idx]; - auto cloc = child_locations[thread_idx * num_fields + field_idx]; + auto cloc = child_locations[flat_index(static_cast(thread_idx), + static_cast(num_fields), + static_cast(field_idx))]; if (mloc.offset >= 0 && cloc.offset >= 0) { data_offset = row_offsets[thread_idx] - base_offset + mloc.offset + cloc.offset; } else { @@ -957,7 +978,13 @@ struct extract_strided_count { int field_idx; int num_fields; - __device__ int32_t operator()(int row) const { return info[row * num_fields + field_idx].count; } + __device__ int32_t operator()(int row) const + { + return info[flat_index(static_cast(row), + static_cast(num_fields), + static_cast(field_idx))] + .count; + } }; /** diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index cfc332af79..d789f57e17 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -43,7 +43,8 @@ __global__ void scan_all_fields_kernel( if (row >= in.size()) return; for (int f = 0; f < num_fields; f++) { - locations[row * num_fields + f] = {-1, 0}; + locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(f))] = {-1, 0}; } if (in.nullable() && in.is_null(row)) { return; } @@ -89,7 +90,14 @@ __global__ void scan_all_fields_kernel( return; } // Record offset pointing to the actual data (after length prefix) - locations[row * num_fields + f] = {data_offset + len_bytes, static_cast(len)}; + int32_t data_location; + if (!checked_add_int32(data_offset, len_bytes, 
data_location)) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(f))] = { + data_location, static_cast(len)}; } else { // For fixed-size and varint fields, record offset and compute length int field_size = get_wire_type_size(wt, cur, msg_end); @@ -97,7 +105,9 @@ __global__ void scan_all_fields_kernel( set_error_once(error_flag, ERR_FIELD_SIZE); return; } - locations[row * num_fields + f] = {data_offset, field_size}; + locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(f))] = { + data_offset, field_size}; } } @@ -305,12 +315,16 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in // Initialize repeated counts to 0 for (int f = 0; f < num_repeated_fields; f++) { - repeated_info[row * num_repeated_fields + f] = {0, 0}; + repeated_info[flat_index( + static_cast(row), static_cast(num_repeated_fields), static_cast(f))] = + {0, 0}; } // Initialize nested locations to not found for (int f = 0; f < num_nested_fields; f++) { - nested_locations[row * num_nested_fields + f] = {-1, 0}; + nested_locations[flat_index( + static_cast(row), static_cast(num_nested_fields), static_cast(f))] = { + -1, 0}; } if (in.nullable() && in.is_null(row)) { return; } @@ -336,12 +350,15 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in if (i >= 0) { int schema_idx = repeated_field_indices[i]; if (schema[schema_idx].depth == depth_level && - !count_repeated_element(cur, - msg_end, - wt, - schema[schema_idx].wire_type, - repeated_info[row * num_repeated_fields + i], - error_flag)) { + !count_repeated_element( + cur, + msg_end, + wt, + schema[schema_idx].wire_type, + repeated_info[flat_index(static_cast(row), + static_cast(num_repeated_fields), + static_cast(i))], + error_flag)) { return; } } @@ -349,12 +366,15 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in for (int i = 0; i < 
num_repeated_fields; i++) { int schema_idx = repeated_field_indices[i]; if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { - if (!count_repeated_element(cur, - msg_end, - wt, - schema[schema_idx].wire_type, - repeated_info[row * num_repeated_fields + i], - error_flag)) { + if (!count_repeated_element( + cur, + msg_end, + wt, + schema[schema_idx].wire_type, + repeated_info[flat_index(static_cast(row), + static_cast(num_repeated_fields), + static_cast(i))], + error_flag)) { return; } } @@ -378,8 +398,20 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in set_error_once(error_flag, ERR_OVERFLOW); return false; } - int32_t msg_offset = static_cast(cur - bytes - start) + len_bytes; - nested_locations[row * num_nested_fields + i] = {msg_offset, static_cast(len)}; + auto const rel_offset64 = static_cast(cur - bytes - start); + if (rel_offset64 < std::numeric_limits::min() || + rel_offset64 > std::numeric_limits::max()) { + set_error_once(error_flag, ERR_OVERFLOW); + return false; + } + int32_t msg_offset; + if (!checked_add_int32(static_cast(rel_offset64), len_bytes, msg_offset)) { + set_error_once(error_flag, ERR_OVERFLOW); + return false; + } + nested_locations[flat_index( + static_cast(row), static_cast(num_nested_fields), static_cast(i))] = + {msg_offset, static_cast(len)}; return true; }; @@ -517,7 +549,8 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, if (row >= num_parent_rows) return; for (int f = 0; f < num_fields; f++) { - output_locations[row * num_fields + f] = {-1, 0}; + output_locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(f))] = {-1, 0}; } auto const& parent_loc = parent_locations[row]; @@ -567,15 +600,23 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, set_error_once(error_flag, ERR_OVERFLOW); return; } - output_locations[row * num_fields + f] = {data_offset + len_bytes, - static_cast(len)}; 
+ int32_t data_location; + if (!checked_add_int32(data_offset, len_bytes, data_location)) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + output_locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(f))] = { + data_location, static_cast(len)}; } else { int field_size = get_wire_type_size(wt, cur, nested_end); if (field_size < 0) { set_error_once(error_flag, ERR_FIELD_SIZE); return; } - output_locations[row * num_fields + f] = {data_offset, field_size}; + output_locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(f))] = { + data_offset, field_size}; } } } @@ -612,7 +653,9 @@ __global__ void scan_repeated_message_children_kernel( // Initialize child locations to not found for (int f = 0; f < num_child_fields; f++) { - child_locs[occ_idx * num_child_fields + f] = {-1, 0}; + child_locs[flat_index(static_cast(occ_idx), + static_cast(num_child_fields), + static_cast(f))] = {-1, 0}; } auto const& msg_loc = msg_locs[occ_idx]; @@ -659,8 +702,15 @@ __global__ void scan_repeated_message_children_kernel( set_error_once(error_flag, ERR_OVERFLOW); return; } - child_locs[occ_idx * num_child_fields + f] = {data_offset + len_bytes, - static_cast(len)}; + int32_t data_location; + if (!checked_add_int32(data_offset, len_bytes, data_location)) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + child_locs[flat_index(static_cast(occ_idx), + static_cast(num_child_fields), + static_cast(f))] = {data_location, + static_cast(len)}; } else { // For varint/fixed types, store offset and estimated length int32_t data_length = 0; @@ -685,7 +735,9 @@ __global__ void scan_repeated_message_children_kernel( } data_length = 8; } - child_locs[occ_idx * num_child_fields + f] = {data_offset, data_length}; + child_locs[flat_index(static_cast(occ_idx), + static_cast(num_child_fields), + static_cast(f))] = {data_offset, data_length}; } } } @@ -727,7 +779,9 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* 
message_data, // Initialize counts for (int ri = 0; ri < num_repeated; ri++) { - repeated_info[row * num_repeated + ri] = {0, 0}; + repeated_info[flat_index( + static_cast(row), static_cast(num_repeated), static_cast(ri))] = {0, + 0}; } auto const& parent_loc = parent_locs[row]; @@ -762,7 +816,9 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, msg_end, wt, schema[schema_idx].wire_type, - repeated_info[row * num_repeated + ri], + repeated_info[flat_index(static_cast(row), + static_cast(num_repeated), + static_cast(ri))], error_flag)) { return; } @@ -867,7 +923,9 @@ __global__ void compute_nested_struct_locations_kernel( int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= total_count) return; - nested_locs[idx] = child_locs[idx * num_child_fields + child_idx]; + nested_locs[idx] = child_locs[flat_index(static_cast(idx), + static_cast(num_child_fields), + static_cast(child_idx))]; auto sum = static_cast(msg_row_offsets[idx]) + msg_locs[idx].offset; if (sum < std::numeric_limits::min() || sum > std::numeric_limits::max()) { @@ -897,7 +955,9 @@ __global__ void compute_grandchild_parent_locations_kernel( if (row >= num_rows) return; auto const& parent_loc = parent_locs[row]; - auto const& child_loc = child_locs[row * num_child_fields + child_idx]; + auto const& child_loc = child_locs[flat_index(static_cast(row), + static_cast(num_child_fields), + static_cast(child_idx))]; if (parent_loc.offset >= 0 && child_loc.offset >= 0) { // Absolute offset = parent offset + child's relative offset @@ -993,7 +1053,8 @@ __global__ void extract_strided_locations_kernel(field_location const* nested_lo { int row = blockIdx.x * blockDim.x + threadIdx.x; if (row >= num_rows) return; - parent_locs[row] = nested_locations[row * num_fields + field_idx]; + parent_locs[row] = nested_locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(field_idx))]; } // ============================================================================ 
@@ -1015,7 +1076,10 @@ __global__ void check_required_fields_kernel( if (row >= num_rows) return; for (int f = 0; f < num_fields; f++) { - if (is_required[f] != 0 && locations[row * num_fields + f].offset < 0) { + if (is_required[f] != 0 && locations[flat_index(static_cast(row), + static_cast(num_fields), + static_cast(f))] + .offset < 0) { // Required field is missing - set error flag set_error_once(error_flag, ERR_REQUIRED); return; // No need to check other fields for this row diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 887e03d76d..4540d75f89 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -2361,6 +2361,23 @@ void testFailfastFieldNumberZero() { } } + @Test + void testFailfastFieldNumberAboveSpecLimit() { + // Protobuf field numbers must be <= 2^29 - 1. + Byte[] row = concat(box(tag(1 << 29, WT_VARINT)), box(encodeVarint(42))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + true)) { + } + }); + } + } + @Test void testFailfastValidDataDoesNotThrow() { // Valid protobuf should not throw even with failOnErrors = true From 96b20e0e5cef1c02272770c4a580b7417b036ed1 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Mar 2026 10:10:09 +0800 Subject: [PATCH 075/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_builders.cu | 6 ++++-- src/main/cpp/src/protobuf_kernels.cu | 4 ++++ .../spark/rapids/jni/ProtobufSchemaDescriptor.java | 11 ++++++++--- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 
3edf50152b..ad1164b11f 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -572,8 +572,9 @@ std::unique_ptr build_repeated_string_column( thrust::exclusive_scan( rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); + int32_t total_count_i32 = static_cast(total_count); CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, - &total_count, + &total_count_i32, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); @@ -767,8 +768,9 @@ std::unique_ptr build_repeated_struct_column( thrust::exclusive_scan( rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); + int32_t total_count_i32 = static_cast(total_count); CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, - &total_count, + &total_count_i32, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index d789f57e17..4f5fe5ec52 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -533,6 +533,10 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co * Scan nested message fields. * Each row represents a nested message at a specific parent location. * This kernel finds fields within the nested message bytes. + * + * Note: this path intentionally keeps a linear child-field scan. Unlike repeated-message child + * scanning, previous benchmarking did not show a stable win from adding a host-built lookup table + * here, so we keep the simpler implementation unless that trade-off changes. 
*/ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, cudf::size_type message_data_size, diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index 6bedc210ba..9b79041747 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -26,9 +26,11 @@ *

    Use this class instead of passing 15+ individual arrays through the JNI boundary. * Validation is performed once in the constructor (and again on deserialization). * - *

    All arrays are defensively copied in the constructor to guarantee immutability. - * Package-private field access from {@link Protobuf} is safe because the stored arrays - * cannot be mutated by the original caller. + *

    All arrays provided to the constructor are defensively copied to guarantee immutability. + * During deserialization, {@code defaultReadObject()} reconstructs a fresh object graph and + * {@link #readObject(java.io.ObjectInputStream)} re-validates the schema invariants before the + * instance becomes visible. Package-private field access from {@link Protobuf} is therefore safe + * because constructor callers cannot retain mutable aliases into the stored arrays. */ public final class ProtobufSchemaDescriptor implements java.io.Serializable { private static final long serialVersionUID = 1L; @@ -106,6 +108,9 @@ public ProtobufSchemaDescriptor( private void readObject(java.io.ObjectInputStream in) throws java.io.IOException, ClassNotFoundException { + // defaultReadObject() reconstructs new array objects from the serialized stream; we do not + // receive caller-owned array aliases here. Re-run validate() so deserialization cannot bypass + // the constructor's schema invariants. in.defaultReadObject(); try { validate(fieldNumbers, parentIndices, depthLevels, wireTypes, outputTypeIds, From 87c5e99c3ddbfd9753e333ee305d756feb1727b3 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Mar 2026 11:32:25 +0800 Subject: [PATCH 076/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.hpp | 12 ++++++ .../rapids/jni/ProtobufSchemaDescriptor.java | 6 +++ .../jni/ProtobufSchemaDescriptorTest.java | 37 +++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index 036735a8fb..9df727fd1d 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -23,6 +23,7 @@ #include #include +#include #include #include #include @@ -146,6 +147,7 @@ inline void validate_decode_context(ProtobufDecodeContext const& context) fail_size("enum_valid_values", context.enum_valid_values.size()); if (context.enum_names.size() != num_fields) fail_size("enum_names", 
context.enum_names.size()); + std::set> seen_field_numbers; for (size_t i = 0; i < num_fields; ++i) { auto const& field = context.schema[i]; auto const& type = context.schema_output_types[i]; @@ -157,10 +159,20 @@ inline void validate_decode_context(ProtobufDecodeContext const& context) throw std::invalid_argument("protobuf decode context: invalid field number at field " + std::to_string(i)); } + if (field.depth < 0 || field.depth >= MAX_NESTING_DEPTH) { + throw std::invalid_argument( + "protobuf decode context: field depth exceeds supported limit at field " + + std::to_string(i)); + } if (field.parent_idx < -1 || field.parent_idx >= static_cast(i)) { throw std::invalid_argument("protobuf decode context: invalid parent index at field " + std::to_string(i)); } + if (!seen_field_numbers.emplace(field.parent_idx, field.field_number).second) { + throw std::invalid_argument( + "protobuf decode context: duplicate field number under same parent at field " + + std::to_string(i)); + } if (field.parent_idx == -1) { if (field.depth != 0) { throw std::invalid_argument( diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index 9b79041747..c1ee8b5202 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -35,6 +35,7 @@ public final class ProtobufSchemaDescriptor implements java.io.Serializable { private static final long serialVersionUID = 1L; private static final int MAX_FIELD_NUMBER = (1 << 29) - 1; + private static final int MAX_NESTING_DEPTH = 10; private static final int STRUCT_TYPE_ID = ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(); private static final int STRING_TYPE_ID = ai.rapids.cudf.DType.STRING.getTypeId().getNativeId(); private static final int LIST_TYPE_ID = ai.rapids.cudf.DType.LIST.getTypeId().getNativeId(); @@ -184,6 +185,11 @@ private 
static void validate( "Invalid field number at index " + i + ": " + fieldNumbers[i] + " (must be 1-" + MAX_FIELD_NUMBER + ")"); } + if (depthLevels[i] < 0 || depthLevels[i] >= MAX_NESTING_DEPTH) { + throw new IllegalArgumentException( + "Invalid depth at index " + i + ": " + depthLevels[i] + + " (must be 0-" + (MAX_NESTING_DEPTH - 1) + ")"); + } int pi = parentIndices[i]; if (pi < -1 || pi >= i) { throw new IllegalArgumentException( diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java index 85a5daa556..8f4d16ba94 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java @@ -177,4 +177,41 @@ void testEncodingCompatibilityValidation() { new int[][]{{0, 1}}, new byte[][][]{new byte[][]{"A".getBytes(), "B".getBytes()}})); } + + @Test + void testDepthAboveSupportedLimitRejected() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + new int[]{-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + new int[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + new int[]{Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_LEN, + Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_LEN, + Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_VARINT}, + new int[]{ + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + 
ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId(), + ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false, false, false, false, false, false, false, false, false}, + new boolean[]{false, false, false, false, false, false, false, false, false, false, false}, + new boolean[]{false, false, false, false, false, false, false, false, false, false, false}, + new long[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false, false, false, false, false, false, false, false}, + new byte[][]{null, null, null, null, null, null, null, null, null, null, null}, + new int[][]{null, null, null, null, null, null, null, null, null, null, null}, + new byte[][][]{null, null, null, null, null, null, null, null, null, null, null})); + } } From fca9ea7be6067552badbc02368b808507f867751 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Mar 2026 12:21:32 +0800 Subject: [PATCH 077/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 77 +++++++++++++++---- src/main/cpp/src/protobuf_common.cuh | 4 + .../jni/ProtobufSchemaDescriptorTest.java | 52 +++++++++++++ .../nvidia/spark/rapids/jni/ProtobufTest.java | 16 ++++ 4 files changed, 132 insertions(+), 17 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index e8eb0a3d39..b7bf6ad71a 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -22,6 +22,60 @@ using namespace spark_rapids_jni::protobuf_detail; namespace spark_rapids_jni { +namespace { + +void apply_parent_mask_to_row_aligned_column(cudf::column& col, + cudf::bitmask_type const* parent_mask_ptr, + 
cudf::size_type parent_null_count, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto child_view = col.mutable_view(); + CUDF_EXPECTS(child_view.size() == num_rows, + "struct child size must match parent row count for null propagation"); + + if (child_view.nullable()) { + auto const child_mask_words = + cudf::num_bitmask_words(static_cast(child_view.size() + child_view.offset())); + std::array masks{child_view.null_mask(), parent_mask_ptr}; + std::array begin_bits{child_view.offset(), 0}; + auto const valid_count = cudf::detail::inplace_bitmask_and( + cudf::device_span(child_view.null_mask(), child_mask_words), + cudf::host_span(masks.data(), masks.size()), + cudf::host_span(begin_bits.data(), begin_bits.size()), + child_view.size(), + stream); + col.set_null_count(child_view.size() - valid_count); + } else { + auto child_mask = cudf::detail::copy_bitmask(parent_mask_ptr, 0, num_rows, stream, mr); + col.set_null_mask(std::move(child_mask), parent_null_count); + } +} + +void propagate_struct_nulls_to_descendants(cudf::column& struct_col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (struct_col.type().id() != cudf::type_id::STRUCT || struct_col.null_count() == 0) { return; } + + auto const struct_view = struct_col.view(); + auto const* struct_mask_ptr = struct_view.null_mask(); + auto const num_rows = struct_view.size(); + auto const null_count = struct_col.null_count(); + + for (cudf::size_type i = 0; i < struct_col.num_children(); ++i) { + auto& child = struct_col.child(i); + apply_parent_mask_to_row_aligned_column( + child, struct_mask_ptr, null_count, num_rows, stream, mr); + if (child.type().id() == cudf::type_id::STRUCT) { + propagate_struct_nulls_to_descendants(child, stream, mr); + } + } +} + +} // namespace + std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, ProtobufDecodeContext const& context, rmm::cuda_stream_view stream) @@ -1061,26 
+1115,15 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } // cuDF struct child views do not inherit parent nulls. Push PERMISSIVE invalid-enum nulls - // down into every top-level child so extracted fields respect "null struct => null field". + // down into every top-level child, then recursively into nested STRUCT descendants, so + // callers that access grandchildren directly still observe logically-null rows. if (has_enum_fields && struct_null_count > 0) { auto const* struct_mask_ptr = static_cast(struct_mask.data()); for (auto& child : top_level_children) { - auto child_view = child->mutable_view(); - if (child_view.nullable()) { - auto const child_mask_words = - cudf::num_bitmask_words(static_cast(child_view.size() + child_view.offset())); - std::array masks{child_view.null_mask(), struct_mask_ptr}; - std::array begin_bits{child_view.offset(), 0}; - auto const valid_count = cudf::detail::inplace_bitmask_and( - cudf::device_span(child_view.null_mask(), child_mask_words), - cudf::host_span(masks.data(), masks.size()), - cudf::host_span(begin_bits.data(), begin_bits.size()), - child_view.size(), - stream); - child->set_null_count(child_view.size() - valid_count); - } else { - auto child_mask = cudf::detail::copy_bitmask(struct_mask_ptr, 0, num_rows, stream, mr); - child->set_null_mask(std::move(child_mask), struct_null_count); + apply_parent_mask_to_row_aligned_column( + *child, struct_mask_ptr, struct_null_count, num_rows, stream, mr); + if (child->type().id() == cudf::type_id::STRUCT) { + propagate_struct_nulls_to_descendants(*child, stream, mr); } } } diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index f5b7f161b6..88cc97532d 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -371,6 +371,10 @@ __device__ inline bool decode_tag(uint8_t const*& cur, } tag.field_number = static_cast(fn); tag.wire_type = static_cast(key & 0x7); + if (tag.wire_type == 
WT_SGROUP || tag.wire_type == WT_EGROUP) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return false; + } return true; } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java index 8f4d16ba94..12efebe685 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java @@ -18,7 +18,15 @@ import org.junit.jupiter.api.Test; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertThrows; public class ProtobufSchemaDescriptorTest { @@ -214,4 +222,48 @@ void testDepthAboveSupportedLimitRejected() { new int[][]{null, null, null, null, null, null, null, null, null, null, null}, new byte[][][]{null, null, null, null, null, null, null, null, null, null, null})); } + + @Test + void testSerializationRoundTripPreservesContentsAndIsolation() throws Exception { + ProtobufSchemaDescriptor original = new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_VARINT}, + new int[]{ai.rapids.cudf.DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ENUM_STRING}, + new boolean[]{false}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{7}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{"def".getBytes()}, + new int[][]{{0, 1}}, + new byte[][][]{new byte[][]{"A".getBytes(), "B".getBytes()}}); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try (ObjectOutputStream oos = new 
ObjectOutputStream(baos)) { + oos.writeObject(original); + } + + ProtobufSchemaDescriptor roundTrip; + try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()))) { + roundTrip = (ProtobufSchemaDescriptor) ois.readObject(); + } + + assertEquals(original.numFields(), roundTrip.numFields()); + assertArrayEquals(original.fieldNumbers, roundTrip.fieldNumbers); + assertArrayEquals(original.defaultStrings[0], roundTrip.defaultStrings[0]); + assertArrayEquals(original.enumValidValues[0], roundTrip.enumValidValues[0]); + assertArrayEquals(original.enumNames[0][0], roundTrip.enumNames[0][0]); + assertArrayEquals(original.enumNames[0][1], roundTrip.enumNames[0][1]); + assertNotSame(original.defaultStrings, roundTrip.defaultStrings); + assertNotSame(original.defaultStrings[0], roundTrip.defaultStrings[0]); + assertNotSame(original.enumValidValues, roundTrip.enumValidValues); + assertNotSame(original.enumValidValues[0], roundTrip.enumValidValues[0]); + assertNotSame(original.enumNames, roundTrip.enumNames); + assertNotSame(original.enumNames[0], roundTrip.enumNames[0]); + assertNotSame(original.enumNames[0][0], roundTrip.enumNames[0][0]); + } } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 4540d75f89..51bfcbddc8 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -2378,6 +2378,22 @@ void testFailfastFieldNumberAboveSpecLimit() { } } + @Test + void testFailfastUnknownEndGroupWireType() { + Byte[] row = concat(box(tag(5, 4))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + true)) { + } + }); + } + 
} + @Test void testFailfastValidDataDoesNotThrow() { // Valid protobuf should not throw even with failOnErrors = true From 8e5473ca86fa5f2d99110d46dbab851a14f22a8d Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Mar 2026 15:42:04 +0800 Subject: [PATCH 078/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 25 ++++-- src/main/cpp/src/protobuf_builders.cu | 24 +++++- src/main/cpp/src/protobuf_common.cuh | 19 +++-- src/main/cpp/src/protobuf_kernels.cu | 7 ++ .../nvidia/spark/rapids/jni/ProtobufTest.java | 84 ++++++++++++++++--- 5 files changed, 133 insertions(+), 26 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index b7bf6ad71a..4b06dd6084 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -157,7 +157,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& stream.value())); auto d_in = cudf::column_device_view::create(binary_input, stream); - // Identify repeated and nested fields at depth 0 std::vector repeated_field_indices; std::vector nested_field_indices; @@ -327,8 +326,16 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& d_error.data()); // Required-field validation applies to all scalar leaves, not just top-level numerics. - maybe_check_required_fields( - d_locations.data(), scalar_field_indices, schema, num_rows, d_error.data(), stream, mr); + maybe_check_required_fields(d_locations.data(), + scalar_field_indices, + schema, + num_rows, + binary_input.null_count() > 0 ? binary_input.null_mask() : nullptr, + binary_input.offset(), + nullptr, + d_error.data(), + stream, + mr); // Batched scalar extraction: group non-special fixed-width fields by extraction // category and extract all fields of each category with a single 2D kernel launch. 
@@ -655,8 +662,16 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Required top-level nested messages are tracked in d_nested_locations during the scan/count // pass. - maybe_check_required_fields( - d_nested_locations.data(), nested_field_indices, schema, num_rows, d_error.data(), stream, mr); + maybe_check_required_fields(d_nested_locations.data(), + nested_field_indices, + schema, + num_rows, + binary_input.null_count() > 0 ? binary_input.null_mask() : nullptr, + binary_input.offset(), + nullptr, + d_error.data(), + stream, + mr); // Process repeated fields (three-phase: offsets → combined scan → build columns) if (num_repeated > 0) { diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index ad1164b11f..e026f09f6c 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -848,8 +848,16 @@ std::unique_ptr build_repeated_struct_column( static_cast(d_child_lookup.size())); // Enforce proto2 required semantics for fields inside each repeated message occurrence. - maybe_check_required_fields( - d_child_locs.data(), child_field_indices, schema, total_count, d_error.data(), stream, mr); + maybe_check_required_fields(d_child_locs.data(), + child_field_indices, + schema, + total_count, + nullptr, + 0, + nullptr, + d_error.data(), + stream, + mr); // Note: We no longer need to copy child_locs to host because: // 1. All scalar extraction kernels access d_child_locs directly on device @@ -1152,8 +1160,16 @@ std::unique_ptr build_nested_struct_column( d_error.data()); // Enforce proto2 required semantics for direct children of this nested message. 
- maybe_check_required_fields( - d_child_locations.data(), child_field_indices, schema, num_rows, d_error.data(), stream, mr); + maybe_check_required_fields(d_child_locations.data(), + child_field_indices, + schema, + num_rows, + nullptr, + 0, + d_parent_locs.data(), + d_error.data(), + stream, + mr); std::vector> struct_children; for (int ci = 0; ci < num_child_fields; ci++) { diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 88cc97532d..3c0cb6dc8a 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -371,10 +371,6 @@ __device__ inline bool decode_tag(uint8_t const*& cur, } tag.field_number = static_cast(fn); tag.wire_type = static_cast(key & 0x7); - if (tag.wire_type == WT_SGROUP || tag.wire_type == WT_EGROUP) { - set_error_once(error_flag, ERR_WIRE_TYPE); - return false; - } return true; } @@ -1182,12 +1178,18 @@ __global__ void check_required_fields_kernel(field_location const* locations, uint8_t const* is_required, int num_fields, int num_rows, + cudf::bitmask_type const* input_null_mask, + cudf::size_type input_offset, + field_location const* parent_locs, int* error_flag); inline void maybe_check_required_fields(field_location const* locations, std::vector const& field_indices, std::vector const& schema, int num_rows, + cudf::bitmask_type const* input_null_mask, + cudf::size_type input_offset, + field_location const* parent_locs, int* error_flag, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) @@ -1211,7 +1213,14 @@ inline void maybe_check_required_fields(field_location const* locations, auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); check_required_fields_kernel<<>>( - locations, d_is_required.data(), static_cast(field_indices.size()), num_rows, error_flag); + locations, + d_is_required.data(), + static_cast(field_indices.size()), + num_rows, + input_null_mask, + input_offset, + parent_locs, + error_flag); } 
__global__ void validate_enum_values_kernel(int32_t const* values, diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 4f5fe5ec52..74603aac03 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -1074,10 +1074,17 @@ __global__ void check_required_fields_kernel( uint8_t const* is_required, // [num_fields] (1 = required, 0 = optional) int num_fields, int num_rows, + cudf::bitmask_type const* input_null_mask, // optional top-level input null mask + cudf::size_type input_offset, // bit offset for sliced top-level input + field_location const* parent_locs, // [num_rows] optional parent presence for nested rows int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; + if (input_null_mask != nullptr && !cudf::bit_is_set(input_null_mask, row + input_offset)) { + return; + } + if (parent_locs != nullptr && parent_locs[row].offset < 0) return; for (int f = 0; f < num_fields; f++) { if (is_required[f] != 0 && locations[flat_index(static_cast(row), diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 51bfcbddc8..c7853bd081 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -1316,6 +1316,30 @@ void testRequiredFieldWithMultipleRows() { } } + @Test + void testRequiredFieldIgnoresNullInputRow_Failfast() { + Byte[] row0 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(42))); + Byte[] row1 = null; + + try (Table input = new Table.TestBuilder().column(row0, row1).build(); + ColumnVector actualStruct = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + true); + ColumnVector idCol = actualStruct.getChildColumnView(0).copyToColumnVector(); + 
HostColumnVector hostStruct = actualStruct.copyToHost(); + HostColumnVector hostId = idCol.copyToHost()) { + assertEquals(0, actualStruct.getNullCount(), "Null input rows keep the top-level struct row"); + assertFalse(hostStruct.isNull(0), "Present required field should keep row 0 valid"); + assertFalse(hostStruct.isNull(1), "Null input row should not trigger required-field failure"); + assertEquals(1, idCol.getNullCount(), "The required child value should be null on the null input row"); + assertTrue(hostId.isNull(1), "Null input row should produce a null child value, not ERR_REQUIRED"); + } + } + @Test void testRequiredNestedMessageMissing_Failfast() { // message Outer { required Inner detail = 1; } @@ -1383,6 +1407,40 @@ void testRequiredFieldInsideNestedMessageMissing_Failfast() { } } + @Test + void testAbsentNestedParentSkipsRequiredChildCheck_Failfast() { + // message Outer { optional Inner detail = 1; } + // message Inner { required int32 id = 1; } + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, + new boolean[]{false, true}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + true); + ColumnVector detailCol = actual.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostStruct = actual.copyToHost(); + HostColumnVector hostDetail = detailCol.copyToHost()) { + assertEquals(0, actual.getNullCount(), "Outer row should remain valid"); + assertFalse(hostStruct.isNull(0), "Top-level row should not be null"); + assertEquals(1, detailCol.getNullCount(), "Absent 
nested parent should stay null"); + assertTrue(hostDetail.isNull(0), "Missing optional nested struct should skip required-child error"); + } + } + // ============================================================================ // Default Value Tests (API accepts parameters, CUDA fill not yet implemented) // ============================================================================ @@ -2379,18 +2437,20 @@ void testFailfastFieldNumberAboveSpecLimit() { } @Test - void testFailfastUnknownEndGroupWireType() { - Byte[] row = concat(box(tag(5, 4))); - try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { - assertThrows(ai.rapids.cudf.CudfException.class, () -> { - try (ColumnVector result = decodeAllFields( - input.getColumn(0), - new int[]{1}, - new int[]{DType.INT64.getTypeId().getNativeId()}, - new int[]{Protobuf.ENC_DEFAULT}, - true)) { - } - }); + void testUnknownEndGroupWireTypeDoesNotAbortDecode() { + Byte[] row = concat( + box(tag(5, 4)), + box(tag(1, WT_VARINT)), box(encodeVarint(42))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expected = ColumnVector.fromBoxedLongs(42L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected); + ColumnVector actual = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actual); } } From 95e258698cfa24fa916f843280acae1c7d1a3976 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Mar 2026 16:56:26 +0800 Subject: [PATCH 079/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_kernels.cu | 6 ++++++ .../java/com/nvidia/spark/rapids/jni/Protobuf.java | 11 ++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 74603aac03..24d4a771db 100644 --- 
a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -28,6 +28,12 @@ namespace spark_rapids_jni::protobuf_detail { * * For "last one wins" semantics (protobuf standard for repeated scalars), * we continue scanning even after finding a field. + * + * If a row hits a parse error that leaves the cursor in an unsafe state (for example, malformed + * varint bytes or a schema-matching field with the wrong wire type), the scan aborts for that row + * instead of guessing where the next field begins. In permissive mode this means fields after the + * error position are treated as "not found" and therefore fall back to the usual null/default + * missing-field semantics. */ __global__ void scan_all_fields_kernel( cudf::column_device_view const d_in, diff --git a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java index 1660457d79..512ef9e3c7 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java @@ -44,6 +44,12 @@ *

  • LENGTH_DELIMITED: {@code string}, {@code bytes}, nested {@code message}
  • *
  • Nested messages and repeated fields
  • *
+ * + *

In permissive mode ({@code failOnErrors=false}), if decoding encounters a row-local parse + * error from which it cannot safely recover its cursor position (for example, an unexpected wire + * type or malformed varint), scanning for that row stops at the error position. Fields that appear + * later in the same message are therefore treated as "not found" and follow the normal + * missing-field semantics (nulls or defaults, depending on the schema metadata). */ public class Protobuf { static { @@ -66,7 +72,10 @@ public class Protobuf { * * @param binaryInput column of type LIST<INT8/UINT8> where each row is one protobuf message. * @param schema descriptor containing flattened schema arrays (field numbers, types, defaults, etc.) - * @param failOnErrors if true, throw an exception on malformed protobuf messages. + * @param failOnErrors if true, throw an exception on malformed protobuf messages. If false, + * malformed rows are handled permissively; when a row-local parse error + * prevents safe resynchronization, later fields in that same row are treated + * as absent rather than continuing from an uncertain cursor position. * @return a cudf STRUCT column with nested structure. 
*/ public static ColumnVector decodeToStruct(ColumnView binaryInput, From 67f2db8ba0b44a02059c36cf6ff6aec249d3ea6f Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Mar 2026 21:59:16 +0800 Subject: [PATCH 080/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_builders.cu | 64 +++++++++++++++++++-------- src/main/cpp/src/protobuf_common.cuh | 26 +++++++---- 2 files changed, 63 insertions(+), 27 deletions(-) diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index e026f09f6c..85f3c023de 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -355,7 +355,8 @@ std::unique_ptr build_enum_string_column( int num_rows, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, - int32_t const* top_row_indices) + int32_t const* top_row_indices, + bool propagate_invalid_rows) { auto const threads = THREADS_PER_BLOCK; auto const blocks = static_cast((num_rows + threads - 1u) / threads); @@ -373,8 +374,13 @@ std::unique_ptr build_enum_string_column( lookup.d_valid_enums.data(), static_cast(valid_enums.size()), num_rows); - propagate_invalid_enum_flags_to_rows( - d_item_has_invalid_enum, d_row_has_invalid_enum, num_rows, top_row_indices, stream, mr); + propagate_invalid_enum_flags_to_rows(d_item_has_invalid_enum, + d_row_has_invalid_enum, + num_rows, + top_row_indices, + propagate_invalid_rows, + stream, + mr); return build_enum_string_values_column(enum_values, valid, lookup, num_rows, stream, mr); } @@ -390,6 +396,7 @@ inline std::unique_ptr build_repeated_msg_child_enum_string_column std::vector> const& enum_name_bytes, rmm::device_uvector& d_row_has_invalid_enum, int32_t const* top_row_indices, + bool propagate_invalid_rows, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) @@ -428,8 +435,13 @@ inline std::unique_ptr build_repeated_msg_child_enum_string_column lookup.d_valid_enums.data(), 
static_cast(valid_enums.size()), total_count); - propagate_invalid_enum_flags_to_rows( - d_elem_has_invalid_enum, d_row_has_invalid_enum, total_count, top_row_indices, stream, mr); + propagate_invalid_enum_flags_to_rows(d_elem_has_invalid_enum, + d_row_has_invalid_enum, + total_count, + top_row_indices, + propagate_invalid_rows, + stream, + mr); return build_enum_string_values_column(enum_values, valid, lookup, total_count, stream, mr); } @@ -493,6 +505,7 @@ std::unique_ptr build_repeated_enum_string_column( d_row_has_invalid_enum, total_count, d_top_row_indices.data(), + true, stream, mr); @@ -663,7 +676,8 @@ std::unique_ptr build_nested_struct_column( rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, int32_t const* top_row_indices, - int depth); + int depth, + bool propagate_invalid_rows); // Forward declaration -- build_repeated_child_list_column is defined after // build_nested_struct_column but both build_repeated_struct_column and build_nested_struct_column @@ -690,7 +704,8 @@ std::unique_ptr build_repeated_child_list_column( rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, int32_t const* top_row_indices, - int depth); + int depth, + bool propagate_invalid_rows); std::unique_ptr build_repeated_struct_column( cudf::column_view const& binary_input, @@ -896,7 +911,8 @@ std::unique_ptr build_repeated_struct_column( stream, mr, d_top_row_indices.data(), - 1)); + 1, + false)); continue; } @@ -934,7 +950,8 @@ std::unique_ptr build_repeated_struct_column( d_error, stream, mr, - d_top_row_indices.data())); + d_top_row_indices.data(), + true)); break; } case cudf::type_id::STRING: { @@ -955,6 +972,7 @@ std::unique_ptr build_repeated_struct_column( enum_names[child_schema_idx], d_row_has_invalid_enum, d_top_row_indices.data(), + true, d_error, stream, mr)); @@ -1043,7 +1061,8 @@ std::unique_ptr build_repeated_struct_column( stream, mr, d_top_row_indices.data(), - 0)); + 0, + false)); } break; } @@ -1101,7 +1120,8 @@ std::unique_ptr 
build_nested_struct_column( rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, int32_t const* top_row_indices, - int depth) + int depth, + bool propagate_invalid_rows) { CUDF_EXPECTS(depth < MAX_NESTED_STRUCT_DECODE_DEPTH, "Nested protobuf struct depth exceeds supported decode recursion limit"); @@ -1201,7 +1221,8 @@ std::unique_ptr build_nested_struct_column( stream, mr, top_row_indices, - depth)); + depth, + false)); continue; } @@ -1239,7 +1260,8 @@ std::unique_ptr build_nested_struct_column( d_error, stream, mr, - top_row_indices)); + top_row_indices, + propagate_invalid_rows)); break; } case cudf::type_id::STRING: { @@ -1276,7 +1298,8 @@ std::unique_ptr build_nested_struct_column( num_rows, stream, mr, - top_row_indices)); + top_row_indices, + propagate_invalid_rows)); } else { { int err_val = ERR_MISSING_ENUM_META; @@ -1410,7 +1433,8 @@ std::unique_ptr build_nested_struct_column( stream, mr, top_row_indices, - depth + 1)); + depth + 1, + propagate_invalid_rows)); break; } default: struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); break; @@ -1455,7 +1479,8 @@ std::unique_ptr build_repeated_child_list_column( rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, int32_t const* top_row_indices, - int depth) + int depth, + bool propagate_invalid_rows) { auto const threads = THREADS_PER_BLOCK; auto const blocks = static_cast((num_parent_rows + threads - 1u) / threads); @@ -1592,7 +1617,8 @@ std::unique_ptr build_repeated_child_list_column( d_error, stream, mr, - d_rep_top_row_indices.data()); + d_rep_top_row_indices.data(), + propagate_invalid_rows); } else if (elem_type_id == cudf::type_id::STRING || elem_type_id == cudf::type_id::LIST) { if (elem_type_id == cudf::type_id::STRING && schema[child_schema_idx].encoding == spark_rapids_jni::ENC_ENUM_STRING) { @@ -1630,6 +1656,7 @@ std::unique_ptr build_repeated_child_list_column( d_row_has_invalid_enum, total_rep_count, d_rep_top_row_indices.data(), + 
propagate_invalid_rows, stream, mr); child_values = @@ -1702,7 +1729,8 @@ std::unique_ptr build_repeated_child_list_column( stream, mr, d_rep_top_row_indices.data(), - depth + 1); + depth + 1, + propagate_invalid_rows); } } else { child_values = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 3c0cb6dc8a..584463eda2 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -1252,10 +1252,11 @@ inline void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const rmm::device_uvector& row_invalid, int num_items, int32_t const* top_row_indices, + bool propagate_to_rows, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - if (num_items == 0 || row_invalid.size() == 0) { return; } + if (num_items == 0 || row_invalid.size() == 0 || !propagate_to_rows) { return; } if (top_row_indices == nullptr) { CUDF_EXPECTS(static_cast(num_items) <= row_invalid.size(), @@ -1298,6 +1299,7 @@ inline void validate_enum_and_propagate_rows(rmm::device_uvector const& rmm::device_uvector& row_invalid, int num_items, int32_t const* top_row_indices, + bool propagate_to_rows, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { @@ -1322,7 +1324,7 @@ inline void validate_enum_and_propagate_rows(rmm::device_uvector const& num_items); propagate_invalid_enum_flags_to_rows( - item_invalid, row_invalid, num_items, top_row_indices, stream, mr); + item_invalid, row_invalid, num_items, top_row_indices, propagate_to_rows, stream, mr); } // ============================================================================ @@ -1349,7 +1351,8 @@ std::unique_ptr build_enum_string_column( int num_rows, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, - int32_t const* top_row_indices = nullptr); + int32_t const* top_row_indices = nullptr, + bool propagate_invalid_rows = true); // Complex builder forward declarations 
std::unique_ptr build_repeated_enum_string_column( @@ -1405,7 +1408,8 @@ std::unique_ptr build_nested_struct_column( rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, int32_t const* top_row_indices, - int depth); + int depth, + bool propagate_invalid_rows = true); std::unique_ptr build_repeated_child_list_column( uint8_t const* message_data, @@ -1429,7 +1433,8 @@ std::unique_ptr build_repeated_child_list_column( rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, int32_t const* top_row_indices, - int depth); + int depth, + bool propagate_invalid_rows = true); std::unique_ptr build_repeated_struct_column( cudf::column_view const& binary_input, @@ -1480,7 +1485,7 @@ inline std::unique_ptr extract_and_build_string_or_bytes_column( rmm::device_uvector lengths(num_rows, stream, mr); auto const threads = THREADS_PER_BLOCK; - auto const blocks = (num_rows + threads - 1) / threads; + auto const blocks = static_cast((num_rows + threads - 1u) / threads); extract_lengths_kernel<<>>( length_provider, num_rows, lengths.data(), has_default, def_len); @@ -1544,7 +1549,8 @@ inline std::unique_ptr extract_typed_column( rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, - int32_t const* top_row_indices = nullptr) + int32_t const* top_row_indices = nullptr, + bool propagate_invalid_rows = true) { switch (dt.id()) { case cudf::type_id::BOOL8: { @@ -1591,6 +1597,7 @@ inline std::unique_ptr extract_typed_column( d_row_has_invalid_enum, num_items, top_row_indices, + propagate_invalid_rows, stream, mr); } @@ -1733,8 +1740,9 @@ inline std::unique_ptr build_repeated_scalar_column( thrust::exclusive_scan( rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); + int32_t total_count_i32 = static_cast(total_count); CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, - &total_count, + &total_count_i32, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); @@ -1742,7 +1750,7 
@@ inline std::unique_ptr build_repeated_scalar_column( rmm::device_uvector values(total_count, stream, mr); auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1) / threads; + auto const blocks = static_cast((total_count + threads - 1u) / threads); int encoding = field_desc.encoding; bool zigzag = (encoding == spark_rapids_jni::ENC_ZIGZAG); From 06c15ee72aaae0d51d6ea876c700a5c5124b7364 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Mar 2026 22:45:48 +0800 Subject: [PATCH 081/107] address greptile comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 4b06dd6084..feaa8d07b4 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -1075,7 +1075,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& stream, mr, nullptr, - 0); + 0, + false); } } From 14c66442d9e16f19604a38fd758cea079dc8af6f Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 13 Mar 2026 23:10:45 +0800 Subject: [PATCH 082/107] bugfix Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf_builders.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 85f3c023de..4b28cdc4a8 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -951,7 +951,7 @@ std::unique_ptr build_repeated_struct_column( stream, mr, d_top_row_indices.data(), - true)); + false)); break; } case cudf::type_id::STRING: { @@ -972,7 +972,7 @@ std::unique_ptr build_repeated_struct_column( enum_names[child_schema_idx], d_row_has_invalid_enum, d_top_row_indices.data(), - true, + false, d_error, stream, mr)); From 7a06053a7c900188774b347d7d9ad18e75259079 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sat, 14 Mar 2026 12:51:18 +0800 Subject: [PATCH 083/107] address comments 
Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 2 +- src/main/cpp/src/protobuf_builders.cu | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index feaa8d07b4..b7ca98fe70 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -1076,7 +1076,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& mr, nullptr, 0, - false); + true); } } diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 4b28cdc4a8..7b70b6e115 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -595,7 +595,7 @@ std::unique_ptr build_repeated_string_column( // Extract string lengths from occurrences rmm::device_uvector str_lengths(total_count, stream, mr); auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1u) / threads; + auto const blocks = static_cast((total_count + threads - 1u) / threads); RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; extract_lengths_kernel <<>>(loc_provider, total_count, str_lengths.data()); @@ -822,7 +822,7 @@ std::unique_ptr build_repeated_struct_column( rmm::device_uvector d_msg_row_offsets(total_count, stream, mr); { auto const occ_threads = THREADS_PER_BLOCK; - auto const occ_blocks = (total_count + occ_threads - 1u) / occ_threads; + auto const occ_blocks = static_cast((total_count + occ_threads - 1u) / occ_threads); compute_msg_locations_from_occurrences_kernel<<>>( d_occurrences.data(), list_offsets, @@ -845,7 +845,7 @@ std::unique_ptr build_repeated_struct_column( auto& d_error = d_error_top; auto const threads = THREADS_PER_BLOCK; - auto const blocks = (total_count + threads - 1u) / threads; + auto const blocks = static_cast((total_count + threads - 1u) / threads); // Use a custom kernel to scan child fields within message occurrences // This is similar to 
scan_nested_message_fields_kernel but operates on occurrences From 809daa15e679918ec73ab5ad063c2fbdecdf8054 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sat, 14 Mar 2026 17:48:34 +0800 Subject: [PATCH 084/107] nits Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 22 +++-------- src/main/cpp/src/protobuf_builders.cu | 57 ++++++--------------------- src/main/cpp/src/protobuf_common.cuh | 29 ++++++-------- src/main/cpp/src/protobuf_kernels.cu | 36 ++++++++++++++++- 4 files changed, 65 insertions(+), 79 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index b7ca98fe70..735c3e21ba 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -110,9 +110,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& if (schema[i].is_repeated && field_type.id() == cudf::type_id::STRUCT) { // Repeated message field - build empty LIST with proper struct element rmm::device_uvector offsets(1, stream, mr); - int32_t zero = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync( - offsets.data(), &zero, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + CUDF_CUDA_TRY(cudaMemsetAsync(offsets.data(), 0, sizeof(int32_t), stream.value())); auto offsets_col = std::make_unique( cudf::data_type{cudf::type_id::INT32}, 1, offsets.release(), rmm::device_buffer{}, 0); auto empty_struct = make_empty_struct_column_with_schema( @@ -199,6 +197,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case ERR_MISSING_ENUM_META: return "Protobuf decode error: missing or mismatched enum metadata for enum-as-string " "field"; + case ERR_REPEATED_COUNT_MISMATCH: + return "Protobuf decode error: repeated-field count/scan mismatch"; default: return "Protobuf decode error: unknown error"; } }; @@ -578,19 +578,11 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& mr); } else { // Missing enum metadata for enum-as-string field; mark as decode error. 
- { - int err_val = ERR_MISSING_ENUM_META; - CUDF_CUDA_TRY(cudaMemcpyAsync( - d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); - } + thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); column_map[schema_idx] = make_null_column(dt, num_rows, stream, mr); } } else { - { - int err_val = ERR_MISSING_ENUM_META; - CUDF_CUDA_TRY(cudaMemcpyAsync( - d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); - } + thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); column_map[schema_idx] = make_null_column(dt, num_rows, stream, mr); } } else { @@ -917,9 +909,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& stream, mr); } else { - int err_val = ERR_MISSING_ENUM_META; - CUDF_CUDA_TRY(cudaMemcpyAsync( - d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); + thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); column_map[schema_idx] = make_null_column(schema_output_types[schema_idx], num_rows, stream, mr); } diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 7b70b6e115..609301cfa9 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -180,12 +180,8 @@ std::unique_ptr make_empty_column_safe(cudf::data_type dtype, rmm::device_buffer{}, 0); // Initialize offset to 0 - int32_t zero = 0; - CUDF_CUDA_TRY(cudaMemcpyAsync(offsets_col->mutable_view().data(), - &zero, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + CUDF_CUDA_TRY(cudaMemsetAsync( + offsets_col->mutable_view().data(), 0, sizeof(int32_t), stream.value())); auto child_col = std::make_unique( cudf::data_type{cudf::type_id::UINT8}, 0, rmm::device_buffer{}, rmm::device_buffer{}, 0); return cudf::make_lists_column( @@ -237,12 +233,8 @@ std::unique_ptr make_empty_list_column(std::unique_ptrmutable_view().data(), - &zero, - 
sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + CUDF_CUDA_TRY(cudaMemsetAsync( + offsets_col->mutable_view().data(), 0, sizeof(int32_t), stream.value())); return cudf::make_lists_column( 0, std::move(offsets_col), std::move(element_col), 0, rmm::device_buffer{}); } @@ -517,8 +509,7 @@ std::unique_ptr build_repeated_enum_string_column( thrust::exclusive_scan( rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), lo.begin(), 0); int32_t tc_i32 = static_cast(total_count); - CUDF_CUDA_TRY(cudaMemcpyAsync( - lo.data() + num_rows, &tc_i32, sizeof(int32_t), cudaMemcpyHostToDevice, stream.value())); + thrust::fill_n(rmm::exec_policy(stream), lo.data() + num_rows, 1, tc_i32); auto list_offs_col = std::make_unique( cudf::data_type{cudf::type_id::INT32}, num_rows + 1, lo.release(), rmm::device_buffer{}, 0); @@ -586,11 +577,7 @@ std::unique_ptr build_repeated_string_column( rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); int32_t total_count_i32 = static_cast(total_count); - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, - &total_count_i32, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + thrust::fill_n(rmm::exec_policy(stream), list_offs.data() + num_rows, 1, total_count_i32); // Extract string lengths from occurrences rmm::device_uvector str_lengths(total_count, stream, mr); @@ -784,11 +771,7 @@ std::unique_ptr build_repeated_struct_column( rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); int32_t total_count_i32 = static_cast(total_count); - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, - &total_count_i32, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + thrust::fill_n(rmm::exec_policy(stream), list_offs.data() + num_rows, 1, total_count_i32); // Build child field descriptors for scanning within each message occurrence std::vector h_child_descs(num_child_fields); @@ -977,9 +960,7 @@ 
std::unique_ptr build_repeated_struct_column( stream, mr)); } else { - int err_val = ERR_MISSING_ENUM_META; - CUDF_CUDA_TRY(cudaMemcpyAsync( - d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); + thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); struct_children.push_back(make_null_column(dt, total_count, stream, mr)); } } else { @@ -1301,19 +1282,11 @@ std::unique_ptr build_nested_struct_column( top_row_indices, propagate_invalid_rows)); } else { - { - int err_val = ERR_MISSING_ENUM_META; - CUDF_CUDA_TRY(cudaMemcpyAsync( - d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); - } + thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); } } else { - { - int err_val = ERR_MISSING_ENUM_META; - CUDF_CUDA_TRY(cudaMemcpyAsync( - d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); - } + thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); } } else { @@ -1560,11 +1533,7 @@ std::unique_ptr build_repeated_child_list_column( rmm::device_uvector list_offs(num_parent_rows + 1, stream, mr); thrust::exclusive_scan( rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), list_offs.begin(), 0); - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_parent_rows, - &total_rep_count, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + thrust::fill_n(rmm::exec_policy(stream), list_offs.data() + num_parent_rows, 1, total_rep_count); rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); scan_repeated_in_nested_kernel<<>>(message_data, @@ -1662,9 +1631,7 @@ std::unique_ptr build_repeated_child_list_column( child_values = build_enum_string_values_column(enum_values, valid, lookup, total_rep_count, stream, mr); } else { - int err_val = 
ERR_MISSING_ENUM_META; - CUDF_CUDA_TRY(cudaMemcpyAsync( - d_error.data(), &err_val, sizeof(int), cudaMemcpyHostToDevice, stream.value())); + thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); child_values = make_null_column(cudf::data_type{elem_type_id}, total_rep_count, stream, mr); } } else { diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 584463eda2..dd37f0107a 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -68,17 +68,18 @@ constexpr int MAX_VARINT_BYTES = 10; constexpr int THREADS_PER_BLOCK = 256; // Error codes for kernel error reporting. -constexpr int ERR_BOUNDS = 1; -constexpr int ERR_VARINT = 2; -constexpr int ERR_FIELD_NUMBER = 3; -constexpr int ERR_WIRE_TYPE = 4; -constexpr int ERR_OVERFLOW = 5; -constexpr int ERR_FIELD_SIZE = 6; -constexpr int ERR_SKIP = 7; -constexpr int ERR_FIXED_LEN = 8; -constexpr int ERR_REQUIRED = 9; -constexpr int ERR_SCHEMA_TOO_LARGE = 10; -constexpr int ERR_MISSING_ENUM_META = 11; +constexpr int ERR_BOUNDS = 1; +constexpr int ERR_VARINT = 2; +constexpr int ERR_FIELD_NUMBER = 3; +constexpr int ERR_WIRE_TYPE = 4; +constexpr int ERR_OVERFLOW = 5; +constexpr int ERR_FIELD_SIZE = 6; +constexpr int ERR_SKIP = 7; +constexpr int ERR_FIXED_LEN = 8; +constexpr int ERR_REQUIRED = 9; +constexpr int ERR_SCHEMA_TOO_LARGE = 10; +constexpr int ERR_MISSING_ENUM_META = 11; +constexpr int ERR_REPEATED_COUNT_MISMATCH = 12; // Maximum supported nesting depth for recursive struct decoding. 
constexpr int MAX_NESTED_STRUCT_DECODE_DEPTH = 10; @@ -1741,11 +1742,7 @@ inline std::unique_ptr build_repeated_scalar_column( rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); int32_t total_count_i32 = static_cast(total_count); - CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, - &total_count_i32, - sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + thrust::fill_n(rmm::exec_policy(stream), list_offs.data() + num_rows, 1, total_count_i32); rmm::device_uvector values(total_count, stream, mr); diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 24d4a771db..03a760efe2 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -219,6 +219,7 @@ __device__ bool scan_repeated_element(uint8_t const* cur, int32_t row, repeated_occurrence* occurrences, int& write_idx, + int write_end, int* error_flag) { bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); @@ -252,6 +253,10 @@ __device__ bool scan_repeated_element(uint8_t const* cur, set_error_once(error_flag, ERR_VARINT); return false; } + if (write_idx >= write_end) { + set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); + return false; + } occurrences[write_idx] = {row, elem_offset, vbytes}; write_idx++; p += vbytes; @@ -262,6 +267,10 @@ __device__ bool scan_repeated_element(uint8_t const* cur, return false; } for (uint64_t i = 0; i < packed_len; i += 4) { + if (write_idx >= write_end) { + set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); + return false; + } occurrences[write_idx] = {row, static_cast(packed_start - msg_base + i), 4}; write_idx++; } @@ -271,6 +280,10 @@ __device__ bool scan_repeated_element(uint8_t const* cur, return false; } for (uint64_t i = 0; i < packed_len; i += 8) { + if (write_idx >= write_end) { + set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); + return false; + } occurrences[write_idx] = {row, static_cast(packed_start - 
msg_base + i), 8}; write_idx++; } @@ -281,6 +294,10 @@ __device__ bool scan_repeated_element(uint8_t const* cur, set_error_once(error_flag, ERR_FIELD_SIZE); return false; } + if (write_idx >= write_end) { + set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); + return false; + } int32_t abs_offset = static_cast(cur - msg_base) + data_offset; occurrences[write_idx] = {row, abs_offset, data_length}; write_idx++; @@ -503,6 +520,7 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co static_cast(row), scan_descs[f].occurrences, write_idx[f], + scan_descs[f].row_offsets[row + 1], error_flag); } set_error_once(error_flag, ERR_WIRE_TYPE); @@ -529,6 +547,13 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co } cur = next; } + + for (int f = 0; f < num_scan_fields; f++) { + if (write_idx[f] != scan_descs[f].row_offsets[row + 1]) { + set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); + return; + } + } } // ============================================================================ @@ -866,8 +891,13 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; + int write_idx = occ_prefix_sums[row]; + int write_end = occ_prefix_sums[row + 1]; auto const& parent_loc = parent_locs[row]; - if (parent_loc.offset < 0) return; + if (parent_loc.offset < 0) { + if (write_idx != write_end) { set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); } + return; + } cudf::size_type row_off = row_offsets[row] - base_offset; @@ -882,7 +912,6 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, uint8_t const* msg_end = msg_start + parent_loc.length; uint8_t const* cur = msg_start; - int write_idx = occ_prefix_sums[row]; int schema_idx = repeated_indices[0]; while (cur < msg_end) { @@ -900,6 +929,7 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, 
static_cast(row), occurrences, write_idx, + write_end, error_flag)) { return; } @@ -912,6 +942,8 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, } cur = next; } + + if (write_idx != write_end) { set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); } } /** From 3067f0bf144e9a460f3fc1811284ea7a130f8398 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sat, 14 Mar 2026 22:42:33 +0800 Subject: [PATCH 085/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 76 ++++++++++++++++--- src/main/cpp/src/protobuf_builders.cu | 10 +-- .../nvidia/spark/rapids/jni/ProtobufTest.java | 70 +++++++++++++++++ 3 files changed, 142 insertions(+), 14 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 735c3e21ba..203d3a400f 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -24,6 +24,10 @@ namespace spark_rapids_jni { namespace { +void propagate_nulls_to_descendants(cudf::column& col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + void apply_parent_mask_to_row_aligned_column(cudf::column& col, cudf::bitmask_type const* parent_mask_ptr, cudf::size_type parent_null_count, @@ -53,6 +57,53 @@ void apply_parent_mask_to_row_aligned_column(cudf::column& col, } } +void propagate_list_nulls_to_descendants(cudf::column& list_col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (list_col.type().id() != cudf::type_id::LIST || list_col.null_count() == 0) { return; } + + cudf::lists_column_view const list_view(list_col.view()); + auto const* list_mask_ptr = list_view.null_mask(); + auto const num_rows = list_view.size(); + auto& child = list_col.child(cudf::lists_column_view::child_column_index); + auto const child_size = child.size(); + if (child_size == 0) { return; } + + CUDF_EXPECTS(list_view.offset() == 0, + "decoder list null propagation expects unsliced list columns"); + auto const* offsets_begin = 
list_view.offsets_begin(); + // LIST children are not row-aligned with their parent. Expand the list-row null mask across + // every covered child element so direct access to the backing child column also observes nulls. + auto [element_mask, element_null_count] = cudf::detail::valid_if( + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(child_size), + [list_mask_ptr, offsets_begin, num_rows] __device__(cudf::size_type idx) { + cudf::size_type lo = 0; + cudf::size_type hi = num_rows; + while (lo < hi) { + auto const mid = lo + (hi - lo) / 2; + if (offsets_begin[mid + 1] <= idx) { + lo = mid + 1; + } else { + hi = mid; + } + } + return list_mask_ptr == nullptr || cudf::bit_is_set(list_mask_ptr, lo); + }, + stream, + mr); + + apply_parent_mask_to_row_aligned_column( + child, + static_cast(element_mask.data()), + element_null_count, + child_size, + stream, + mr); + propagate_nulls_to_descendants(child, stream, mr); +} + void propagate_struct_nulls_to_descendants(cudf::column& struct_col, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) @@ -68,9 +119,18 @@ void propagate_struct_nulls_to_descendants(cudf::column& struct_col, auto& child = struct_col.child(i); apply_parent_mask_to_row_aligned_column( child, struct_mask_ptr, null_count, num_rows, stream, mr); - if (child.type().id() == cudf::type_id::STRUCT) { - propagate_struct_nulls_to_descendants(child, stream, mr); - } + propagate_nulls_to_descendants(child, stream, mr); + } +} + +void propagate_nulls_to_descendants(cudf::column& col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + switch (col.type().id()) { + case cudf::type_id::STRUCT: propagate_struct_nulls_to_descendants(col, stream, mr); break; + case cudf::type_id::LIST: propagate_list_nulls_to_descendants(col, stream, mr); break; + default: break; } } @@ -1120,17 +1180,15 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& struct_null_count = null_count; } - // cuDF struct child 
views do not inherit parent nulls. Push PERMISSIVE invalid-enum nulls - // down into every top-level child, then recursively into nested STRUCT descendants, so - // callers that access grandchildren directly still observe logically-null rows. + // cuDF child views do not automatically inherit parent nulls. Push PERMISSIVE invalid-enum + // nulls down into every top-level child, then recursively through nested STRUCT/LIST children, + // so callers that access backing grandchildren directly still observe logically-null rows. if (has_enum_fields && struct_null_count > 0) { auto const* struct_mask_ptr = static_cast(struct_mask.data()); for (auto& child : top_level_children) { apply_parent_mask_to_row_aligned_column( *child, struct_mask_ptr, struct_null_count, num_rows, stream, mr); - if (child->type().id() == cudf::type_id::STRUCT) { - propagate_struct_nulls_to_descendants(*child, stream, mr); - } + propagate_nulls_to_descendants(*child, stream, mr); } } diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 609301cfa9..7c7adb4204 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -895,7 +895,7 @@ std::unique_ptr build_repeated_struct_column( mr, d_top_row_indices.data(), 1, - false)); + true)); continue; } @@ -934,7 +934,7 @@ std::unique_ptr build_repeated_struct_column( stream, mr, d_top_row_indices.data(), - false)); + true)); break; } case cudf::type_id::STRING: { @@ -955,7 +955,7 @@ std::unique_ptr build_repeated_struct_column( enum_names[child_schema_idx], d_row_has_invalid_enum, d_top_row_indices.data(), - false, + true, d_error, stream, mr)); @@ -1043,7 +1043,7 @@ std::unique_ptr build_repeated_struct_column( mr, d_top_row_indices.data(), 0, - false)); + true)); } break; } @@ -1203,7 +1203,7 @@ std::unique_ptr build_nested_struct_column( mr, top_row_indices, depth, - false)); + propagate_invalid_rows)); continue; } diff --git 
a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index c7853bd081..7d00c975d3 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -21,6 +21,7 @@ import ai.rapids.cudf.ColumnView; import ai.rapids.cudf.DType; import ai.rapids.cudf.HostColumnVector; +import ai.rapids.cudf.HostColumnVectorCore; import ai.rapids.cudf.HostColumnVector.*; import ai.rapids.cudf.Table; import org.junit.jupiter.api.Disabled; @@ -2775,6 +2776,75 @@ void testRepeatedStructEnumInvalidNullsCorrectTopLevelRow() { } } + @Test + void testRepeatedStructEnumInvalidNullsListBackingStructChildren() { + // enum Color { RED=0; GREEN=1; BLUE=2; } + // message Item { Color color = 1; int32 count = 2; } + // message Msg { repeated Item items = 1; } + Byte[] item00 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(0)), + box(tag(2, WT_VARINT)), box(encodeVarint(10))); + Byte[] item01 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(999)), + box(tag(2, WT_VARINT)), box(encodeVarint(20))); + Byte[] row0 = concat( + box(tag(1, WT_LEN)), box(encodeVarint(item00.length)), item00, + box(tag(1, WT_LEN)), box(encodeVarint(item01.length)), item01); + Byte[] item10 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), + box(tag(2, WT_VARINT)), box(encodeVarint(30))); + Byte[] row1 = concat( + box(tag(1, WT_LEN)), box(encodeVarint(item10.length)), item10); + + try (Table input = new Table.TestBuilder().column(row0, row1).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 1, 2}, + new int[]{-1, 0, 0}, + new int[]{0, 1, 1}, + new int[]{WT_LEN, WT_VARINT, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, false, false}, + new 
boolean[]{false, false, false}, + new boolean[]{false, false, false}, + new long[]{0, 0, 0}, + new double[]{0.0, 0.0, 0.0}, + new boolean[]{false, false, false}, + new byte[][]{null, null, null}, + new int[][]{null, new int[]{0, 1, 2}, null}, + false); + ColumnView itemsView = actual.getChildColumnView(0); + ColumnView itemStructView = itemsView.getChildColumnView(0); + ColumnView countView = itemStructView.getChildColumnView(1); + ColumnVector countVector = countView.copyToColumnVector(); + HostColumnVector hostStruct = actual.copyToHost(); + HostColumnVector hostCounts = countVector.copyToHost()) { + HostColumnVectorCore hostItems = hostStruct.getChildColumnView(0); + + assertEquals(1, actual.getNullCount(), "Exactly one top-level row should be null"); + assertTrue(hostStruct.isNull(0), "Row 0 should be null because one repeated child enum is invalid"); + assertFalse(hostStruct.isNull(1), "Row 1 should remain valid"); + + assertEquals(1, hostItems.getNullCount(), "LIST row should inherit the top-level null"); + assertTrue(hostItems.isNull(0), "items[0] should be null"); + assertFalse(hostItems.isNull(1), "items[1] should remain valid"); + + assertEquals(1, itemStructView.getRowCount(), + "Direct list child view should not expose stale elements from the null list row"); + assertEquals(0, itemStructView.getNullCount(), + "Direct list child view should be sanitized rather than carrying non-empty nulls"); + assertEquals(1, countView.getRowCount(), + "Direct grandchild view should only expose elements from valid list rows"); + assertEquals(0, countView.getNullCount(), + "Direct grandchild view should also be sanitized"); + assertEquals(30, hostCounts.getInt(0), + "The remaining direct child value should come from the valid list row only"); + } + } + @Test void testEnumMissingFieldDoesNotNullRow() { // Missing enum field should return null for the field, but NOT null the entire row From 8aa0bdf661476f50a117db13294964062a65583a Mon Sep 17 00:00:00 2001 From: Haoyang Li 
Date: Sun, 15 Mar 2026 16:26:14 +0800 Subject: [PATCH 086/107] style Signed-off-by: Haoyang Li --- .../nvidia/spark/rapids/jni/ProtobufTest.java | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 7d00c975d3..b0d25b6e91 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -64,27 +64,6 @@ private static byte[] encodeVarint(long value) { return out; } - /** - * Encode a varint with extra padding bytes (over-encoded but valid). - * This is useful for testing that parsers accept non-canonical varints. - */ - private static byte[] encodeLongVarint(long value, int extraBytes) { - byte[] tmp = new byte[10]; - int idx = 0; - long v = value; - while ((v & ~0x7FL) != 0 || extraBytes > 0) { - tmp[idx++] = (byte) ((v & 0x7F) | 0x80); - v >>>= 7; - if (v == 0 && extraBytes > 0) { - extraBytes--; - } - } - tmp[idx++] = (byte) (v & 0x7F); - byte[] out = new byte[idx]; - System.arraycopy(tmp, 0, out, 0, idx); - return out; - } - /** ZigZag encode a signed 32-bit integer, returning as unsigned long for varint encoding. 
*/ private static long zigzagEncode32(int n) { return Integer.toUnsignedLong((n << 1) ^ (n >> 31)); From a13433fa859bf2b4693f2223140eb6fb5c51d16d Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Sun, 15 Mar 2026 17:06:33 +0800 Subject: [PATCH 087/107] style Signed-off-by: Haoyang Li --- .../rapids/jni/ProtobufSchemaDescriptor.java | 4 ++++ .../jni/ProtobufSchemaDescriptorTest.java | 21 +++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index c1ee8b5202..7b8dc838e8 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -235,6 +235,10 @@ private static void validate( "Incompatible wire type / output type / encoding at index " + i + ": wireType=" + wt + ", outputTypeId=" + outputTypeIds[i] + ", encoding=" + enc); } + if (isRepeated[i] && isRequired[i]) { + throw new IllegalArgumentException( + "Field at index " + i + " cannot be both repeated and required"); + } if (isRepeated[i] && hasDefaultValue[i]) { throw new IllegalArgumentException( "Repeated field at index " + i + " cannot carry a default value"); diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java index 12efebe685..3c1ae150eb 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java @@ -60,6 +60,27 @@ void testRepeatedFieldCannotCarryDefaultValue() { makeDescriptor(true, true, Protobuf.ENC_DEFAULT, null, null)); } + @Test + void testFieldCannotBeBothRepeatedAndRequired() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new 
int[]{Protobuf.WT_VARINT}, + new int[]{ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{true}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + new byte[][][]{null})); + } + @Test void testEnumStringRequiresEnumMetadata() { assertThrows(IllegalArgumentException.class, () -> From e0cd7f51006d606bfc9935e62ef14f4f9c127245 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 16 Mar 2026 11:01:54 +0800 Subject: [PATCH 088/107] fix bug Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 24 +- src/main/cpp/src/protobuf.hpp | 2 +- src/main/cpp/src/protobuf_builders.cu | 18 +- src/main/cpp/src/protobuf_common.cuh | 20 +- src/main/cpp/src/protobuf_kernels.cu | 37 ++- .../nvidia/spark/rapids/jni/ProtobufTest.java | 245 +++++++++++++----- 6 files changed, 249 insertions(+), 97 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 203d3a400f..b4dcb3f313 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -262,11 +262,14 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& default: return "Protobuf decode error: unknown error"; } }; - // Enum validation support (PERMISSIVE mode) + // PERMISSIVE-mode row nulling support. Unknown enum values and malformed rows should both + // surface as null structs instead of partially decoded data. bool has_enum_fields = std::any_of( enum_valid_values.begin(), enum_valid_values.end(), [](auto const& v) { return !v.empty(); }); - rmm::device_uvector d_row_has_invalid_enum(has_enum_fields ? num_rows : 0, stream, mr); - if (has_enum_fields) { + bool track_permissive_null_rows = !fail_on_errors; + rmm::device_uvector d_row_has_invalid_enum( + track_permissive_null_rows ? 
num_rows : 0, stream, mr); + if (track_permissive_null_rows) { CUDF_CUDA_TRY( cudaMemsetAsync(d_row_has_invalid_enum.data(), 0, num_rows * sizeof(bool), stream.value())); } @@ -383,7 +386,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& h_field_lookup.empty() ? nullptr : d_field_lookup.data(), static_cast(h_field_lookup.size()), d_locations.data(), - d_error.data()); + d_error.data(), + track_permissive_null_rows ? d_row_has_invalid_enum.data() : nullptr); // Required-field validation applies to all scalar leaves, not just top-level numerics. maybe_check_required_fields(d_locations.data(), @@ -638,11 +642,11 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& mr); } else { // Missing enum metadata for enum-as-string field; mark as decode error. - thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); + set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); column_map[schema_idx] = make_null_column(dt, num_rows, stream, mr); } } else { - thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); + set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); column_map[schema_idx] = make_null_column(dt, num_rows, stream, mr); } } else { @@ -969,7 +973,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& stream, mr); } else { - thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); + set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); column_map[schema_idx] = make_null_column(schema_output_types[schema_idx], num_rows, stream, mr); } @@ -1126,7 +1130,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& mr, nullptr, 0, - true); + false); } } @@ -1167,7 +1171,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& cudf::size_type struct_null_count = 0; rmm::device_buffer struct_mask{0, stream, mr}; - if (has_enum_fields) { + if (track_permissive_null_rows) { 
auto [mask, null_count] = cudf::detail::valid_if( thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), @@ -1183,7 +1187,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // cuDF child views do not automatically inherit parent nulls. Push PERMISSIVE invalid-enum // nulls down into every top-level child, then recursively through nested STRUCT/LIST children, // so callers that access backing grandchildren directly still observe logically-null rows. - if (has_enum_fields && struct_null_count > 0) { + if (track_permissive_null_rows && struct_null_count > 0) { auto const* struct_mask_ptr = static_cast(struct_mask.data()); for (auto& child : top_level_children) { apply_parent_mask_to_row_aligned_column( diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index 9df727fd1d..666cc5c845 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -252,7 +252,7 @@ inline ProtobufFieldMetaView make_field_meta_view(ProtobufDecodeContext const& c * - FLOAT32 : protobuf `float` (fixed32 wire type) * - FLOAT64 : protobuf `double` (fixed64 wire type) * - STRING : protobuf `string` (length-delimited wire type, UTF-8 text) - * - LIST : protobuf `bytes` (length-delimited wire type, raw bytes as LIST) + * - LIST : protobuf `bytes` (length-delimited wire type, raw bytes as LIST) * - STRUCT : protobuf nested `message` * * @param binary_input LIST column, each row is one protobuf message diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 7c7adb4204..04d9b93a2b 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -504,7 +504,7 @@ std::unique_ptr build_repeated_enum_string_column( auto child_col = build_enum_string_values_column(enum_ints, elem_valid, lookup, total_count, stream, mr); - // 8. 
Build LIST column with list offsets from per-row counts + // Build the final LIST column from the per-row counts and decoded child strings. rmm::device_uvector lo(num_rows + 1, stream, mr); thrust::exclusive_scan( rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), lo.begin(), 0); @@ -895,7 +895,7 @@ std::unique_ptr build_repeated_struct_column( mr, d_top_row_indices.data(), 1, - true)); + false)); continue; } @@ -934,7 +934,7 @@ std::unique_ptr build_repeated_struct_column( stream, mr, d_top_row_indices.data(), - true)); + false)); break; } case cudf::type_id::STRING: { @@ -955,12 +955,12 @@ std::unique_ptr build_repeated_struct_column( enum_names[child_schema_idx], d_row_has_invalid_enum, d_top_row_indices.data(), - true, + false, d_error, stream, mr)); } else { - thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); + set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); struct_children.push_back(make_null_column(dt, total_count, stream, mr)); } } else { @@ -1043,7 +1043,7 @@ std::unique_ptr build_repeated_struct_column( mr, d_top_row_indices.data(), 0, - true)); + false)); } break; } @@ -1282,11 +1282,11 @@ std::unique_ptr build_nested_struct_column( top_row_indices, propagate_invalid_rows)); } else { - thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); + set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); } } else { - thrust::fill_n(rmm::exec_policy(stream), d_error.data(), 1, ERR_MISSING_ENUM_META); + set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); } } else { @@ -1631,7 +1631,7 @@ std::unique_ptr build_repeated_child_list_column( child_values = build_enum_string_values_column(enum_values, valid, lookup, total_rep_count, stream, mr); } else { - thrust::fill_n(rmm::exec_policy(stream), 
d_error.data(), 1, ERR_MISSING_ENUM_META); + set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); child_values = make_null_column(cudf::data_type{elem_type_id}, total_rep_count, stream, mr); } } else { diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index dd37f0107a..c090fbc02a 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -197,6 +197,14 @@ __device__ inline void set_error_once(int* error_flag, int error_code) atomicCAS(error_flag, 0, error_code); } +__global__ void set_error_if_unset_kernel(int* error_flag, int error_code); + +inline void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stream_view stream) +{ + set_error_if_unset_kernel<<<1, 1, 0, stream.value()>>>(error_flag, error_code); + CUDF_CUDA_TRY(cudaPeekAtLastError()); +} + __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) { switch (wt) { @@ -278,11 +286,10 @@ __device__ inline bool skip_field(uint8_t const* cur, int wt, uint8_t const*& out_cur) { - // End-group is handled by the parent group parser. - if (wt == WT_EGROUP) { - out_cur = cur; - return true; - } + // A bare end-group is only valid while a start-group payload is being parsed recursively inside + // get_wire_type_size(WT_SGROUP). The scan/count kernels should never accept it as a standalone + // field because Spark CPU treats unmatched end-groups as malformed protobuf. 
+ if (wt == WT_EGROUP) { return false; } int size = get_wire_type_size(wt, cur, end); if (size < 0) return false; @@ -1060,7 +1067,8 @@ __global__ void scan_all_fields_kernel(cudf::column_device_view const d_in, int const* field_lookup, int field_lookup_size, field_location* locations, - int* error_flag); + int* error_flag, + bool* row_has_invalid_data); __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in, device_nested_field_descriptor const* schema, diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 03a760efe2..3bae207808 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -22,6 +22,11 @@ namespace spark_rapids_jni::protobuf_detail { // Pass 1: Scan all fields kernel - records (offset, length) for each field // ============================================================================ +__global__ void set_error_if_unset_kernel(int* error_flag, int error_code) +{ + if (blockIdx.x == 0 && threadIdx.x == 0) { set_error_once(error_flag, error_code); } +} + /** * Fused scanning kernel: scans each message once and records the location * of all requested fields. @@ -31,9 +36,9 @@ namespace spark_rapids_jni::protobuf_detail { * * If a row hits a parse error that leaves the cursor in an unsafe state (for example, malformed * varint bytes or a schema-matching field with the wrong wire type), the scan aborts for that row - * instead of guessing where the next field begins. In permissive mode this means fields after the - * error position are treated as "not found" and therefore fall back to the usual null/default - * missing-field semantics. + * instead of guessing where the next field begins. In permissive mode the caller may also supply a + * row-level invalidity buffer so the full struct row can be nulled to match Spark CPU semantics for + * malformed messages. 
*/ __global__ void scan_all_fields_kernel( cudf::column_device_view const d_in, @@ -42,12 +47,17 @@ __global__ void scan_all_fields_kernel( int const* field_lookup, // direct-mapped lookup table (nullable) int field_lookup_size, // size of lookup table (0 if null) field_location* locations, // [num_rows * num_fields] row-major - int* error_flag) + int* error_flag, + bool* row_has_invalid_data) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; if (row >= in.size()) return; + auto mark_row_error = [&]() { + if (row_has_invalid_data != nullptr) { row_has_invalid_data[row] = true; } + }; + for (int f = 0; f < num_fields; f++) { locations[flat_index( static_cast(row), static_cast(num_fields), static_cast(f))] = {-1, 0}; @@ -61,14 +71,20 @@ __global__ void scan_all_fields_kernel( int32_t start = in.offset_at(row) - base; int32_t end = in.offset_at(row + 1) - base; - if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } + if (!check_message_bounds(start, end, child.size(), error_flag)) { + mark_row_error(); + return; + } uint8_t const* cur = bytes + start; uint8_t const* msg_end = bytes + end; while (cur < msg_end) { proto_tag tag; - if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + if (!decode_tag(cur, msg_end, tag, error_flag)) { + mark_row_error(); + return; + } int fn = tag.field_number; int wt = tag.wire_type; @@ -76,6 +92,7 @@ __global__ void scan_all_fields_kernel( if (f >= 0) { if (wt != field_descs[f].expected_wire_type) { set_error_once(error_flag, ERR_WIRE_TYPE); + mark_row_error(); return; } @@ -88,17 +105,20 @@ __global__ void scan_all_fields_kernel( int len_bytes; if (!read_varint(cur, msg_end, len, len_bytes)) { set_error_once(error_flag, ERR_VARINT); + mark_row_error(); return; } if (len > static_cast(msg_end - cur - len_bytes) || len > static_cast(INT_MAX)) { set_error_once(error_flag, ERR_OVERFLOW); + mark_row_error(); return; } // Record offset pointing to 
the actual data (after length prefix) int32_t data_location; if (!checked_add_int32(data_offset, len_bytes, data_location)) { set_error_once(error_flag, ERR_OVERFLOW); + mark_row_error(); return; } locations[flat_index( @@ -109,6 +129,7 @@ __global__ void scan_all_fields_kernel( int field_size = get_wire_type_size(wt, cur, msg_end); if (field_size < 0) { set_error_once(error_flag, ERR_FIELD_SIZE); + mark_row_error(); return; } locations[flat_index( @@ -121,6 +142,7 @@ __global__ void scan_all_fields_kernel( uint8_t const* next; if (!skip_field(cur, msg_end, wt, next)) { set_error_once(error_flag, ERR_SKIP); + mark_row_error(); return; } cur = next; @@ -614,7 +636,7 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, if (field_descs[f].is_repeated) { // Repeated children are handled by the dedicated count/scan path, not by // the direct-child location scan used for scalar/nested singleton fields. - continue; + break; } if (wt != field_descs[f].expected_wire_type) { set_error_once(error_flag, ERR_WIRE_TYPE); @@ -653,6 +675,7 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, static_cast(row), static_cast(num_fields), static_cast(f))] = { data_offset, field_size}; } + break; } } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index b0d25b6e91..2528ff7a0f 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -247,6 +247,13 @@ private static ColumnVector decodeAllFields(ColumnView binaryInput, new int[numFields][], failOnErrors); } + private static void assertSingleNullStructRow(ColumnVector actual, String message) { + try (HostColumnVector hostStruct = actual.copyToHost()) { + assertEquals(1, actual.getNullCount(), message); + assertTrue(hostStruct.isNull(0), "Row 0 should be null"); + } + } + /** * Helper method for tests with required 
field support. */ @@ -629,10 +636,7 @@ void testMalformedVarint() { new int[]{DType.INT64.getTypeId().getNativeId()}, new int[]{0}, false)) { - try (ColumnVector expected = ColumnVector.fromBoxedLongs((Long)null); - ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { - AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); - } + assertSingleNullStructRow(result, "Malformed varint should null the struct row"); } } @@ -647,10 +651,7 @@ void testTruncatedVarint() { new int[]{DType.INT64.getTypeId().getNativeId()}, new int[]{0}, false)) { - try (ColumnVector expected = ColumnVector.fromBoxedLongs((Long)null); - ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { - AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); - } + assertSingleNullStructRow(result, "Truncated varint should null the struct row"); } } @@ -665,10 +666,8 @@ void testTruncatedLengthDelimited() { new int[]{DType.STRING.getTypeId().getNativeId()}, new int[]{0}, false)) { - try (ColumnVector expected = ColumnVector.fromStrings((String)null); - ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { - AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); - } + assertSingleNullStructRow(result, + "Truncated length-delimited field should null the struct row"); } } @@ -683,10 +682,7 @@ void testTruncatedFixed32() { new int[]{DType.INT32.getTypeId().getNativeId()}, new int[]{Protobuf.ENC_FIXED}, false)) { - try (ColumnVector expected = ColumnVector.fromBoxedInts((Integer)null); - ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { - AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); - } + assertSingleNullStructRow(result, "Truncated fixed32 should null the struct row"); } } @@ -702,10 +698,7 @@ void testTruncatedFixed64() { new int[]{DType.INT64.getTypeId().getNativeId()}, new int[]{Protobuf.ENC_FIXED}, false)) { - try (ColumnVector expected = ColumnVector.fromBoxedLongs((Long)null); - ColumnVector 
expectedStruct = ColumnVector.makeStruct(expected)) { - AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); - } + assertSingleNullStructRow(result, "Truncated fixed64 should null the struct row"); } } @@ -723,10 +716,8 @@ void testPartialLengthDelimitedData() { new int[]{DType.STRING.getTypeId().getNativeId()}, new int[]{0}, false)) { - try (ColumnVector expected = ColumnVector.fromStrings((String)null); - ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { - AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); - } + assertSingleNullStructRow(result, + "Partial length-delimited payload should null the struct row"); } } @@ -747,10 +738,7 @@ void testWrongWireType() { new int[]{DType.INT64.getTypeId().getNativeId()}, // expects varint new int[]{Protobuf.ENC_DEFAULT}, false)) { - try (ColumnVector expected = ColumnVector.fromBoxedLongs((Long)null); - ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { - AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); - } + assertSingleNullStructRow(result, "Wrong wire type should null the struct row"); } } @@ -767,10 +755,7 @@ void testWrongWireTypeForString() { new int[]{DType.STRING.getTypeId().getNativeId()}, // expects LEN new int[]{Protobuf.ENC_DEFAULT}, false)) { - try (ColumnVector expected = ColumnVector.fromStrings((String)null); - ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { - AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); - } + assertSingleNullStructRow(result, "Wrong wire type for string should null the struct row"); } } @@ -940,10 +925,7 @@ void testFieldNumberZeroInvalid() { new int[]{DType.INT64.getTypeId().getNativeId()}, new int[]{0}, false)) { - try (ColumnVector expected = ColumnVector.fromBoxedLongs((Long)null); - ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { - AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); - } + assertSingleNullStructRow(result, "Field number zero 
should null the struct row"); } } @@ -2417,20 +2399,18 @@ void testFailfastFieldNumberAboveSpecLimit() { } @Test - void testUnknownEndGroupWireTypeDoesNotAbortDecode() { + void testUnknownEndGroupWireTypeNullsMalformedRow() { Byte[] row = concat( box(tag(5, 4)), box(tag(1, WT_VARINT)), box(encodeVarint(42))); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector expected = ColumnVector.fromBoxedLongs(42L); - ColumnVector expectedStruct = ColumnVector.makeStruct(expected); ColumnVector actual = decodeAllFields( input.getColumn(0), new int[]{1}, new int[]{DType.INT64.getTypeId().getNativeId()}, new int[]{Protobuf.ENC_DEFAULT}, false)) { - AssertUtils.assertStructColumnsAreEqual(expectedStruct, actual); + assertSingleNullStructRow(actual, "Unknown end-group wire type should null the struct row"); } } @@ -2717,7 +2697,7 @@ void testEnumWithOtherFields_NullsEntireRow() { } @Test - void testRepeatedStructEnumInvalidNullsCorrectTopLevelRow() { + void testRepeatedStructEnumInvalidKeepsTopLevelRowValid() { // enum Color { RED=0; GREEN=1; BLUE=2; } // message Item { Color color = 1; } // message Msg { repeated Item items = 1; } @@ -2748,15 +2728,24 @@ void testRepeatedStructEnumInvalidNullsCorrectTopLevelRow() { new byte[][]{null, null}, new int[][]{null, new int[]{0, 1, 2}}, false); - HostColumnVector hostStruct = actualStruct.copyToHost()) { - assertEquals(1, actualStruct.getNullCount(), "Exactly one top-level row should be null"); - assertTrue(hostStruct.isNull(0), "Row 0 should be null because one repeated child enum is invalid"); + ColumnVector items = actualStruct.getChildColumnView(0).copyToColumnVector(); + ColumnVector itemStructs = items.getChildColumnView(0).copyToColumnVector(); + ColumnVector colors = itemStructs.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostStruct = actualStruct.copyToHost(); + HostColumnVector hostColors = colors.copyToHost()) { + assertEquals(0, actualStruct.getNullCount(), "Invalid 
child enum should not null the top-level row"); + assertFalse(hostStruct.isNull(0), "Row 0 should remain valid"); assertFalse(hostStruct.isNull(1), "Row 1 should remain valid"); + assertEquals(3, colors.getRowCount(), "All repeated message elements should remain visible"); + assertEquals(1, colors.getNullCount(), "Only the invalid enum field should be null"); + assertEquals(0, hostColors.getInt(0), "The first repeated child should keep its valid enum"); + assertTrue(hostColors.isNull(1), "The invalid repeated child enum should decode as null"); + assertEquals(1, hostColors.getInt(2), "The second row should keep its valid enum"); } } @Test - void testRepeatedStructEnumInvalidNullsListBackingStructChildren() { + void testRepeatedStructEnumInvalidKeepsSiblingFieldsVisible() { // enum Color { RED=0; GREEN=1; BLUE=2; } // message Item { Color color = 1; int32 count = 2; } // message Msg { repeated Item items = 1; } @@ -2797,30 +2786,42 @@ void testRepeatedStructEnumInvalidNullsListBackingStructChildren() { false); ColumnView itemsView = actual.getChildColumnView(0); ColumnView itemStructView = itemsView.getChildColumnView(0); + ColumnView colorView = itemStructView.getChildColumnView(0); ColumnView countView = itemStructView.getChildColumnView(1); + ColumnVector colorVector = colorView.copyToColumnVector(); ColumnVector countVector = countView.copyToColumnVector(); HostColumnVector hostStruct = actual.copyToHost(); + HostColumnVector hostColors = colorVector.copyToHost(); HostColumnVector hostCounts = countVector.copyToHost()) { HostColumnVectorCore hostItems = hostStruct.getChildColumnView(0); - assertEquals(1, actual.getNullCount(), "Exactly one top-level row should be null"); - assertTrue(hostStruct.isNull(0), "Row 0 should be null because one repeated child enum is invalid"); + assertEquals(0, actual.getNullCount(), "Invalid child enum should not null the parent row"); + assertFalse(hostStruct.isNull(0), "Row 0 should remain valid"); 
assertFalse(hostStruct.isNull(1), "Row 1 should remain valid"); - assertEquals(1, hostItems.getNullCount(), "LIST row should inherit the top-level null"); - assertTrue(hostItems.isNull(0), "items[0] should be null"); + assertEquals(0, hostItems.getNullCount(), "LIST rows should remain valid"); + assertFalse(hostItems.isNull(0), "items[0] should remain valid"); assertFalse(hostItems.isNull(1), "items[1] should remain valid"); - assertEquals(1, itemStructView.getRowCount(), - "Direct list child view should not expose stale elements from the null list row"); + assertEquals(3, itemStructView.getRowCount(), + "All repeated message elements should remain visible"); assertEquals(0, itemStructView.getNullCount(), - "Direct list child view should be sanitized rather than carrying non-empty nulls"); - assertEquals(1, countView.getRowCount(), - "Direct grandchild view should only expose elements from valid list rows"); + "No repeated struct element should be dropped"); + assertEquals(1, colorView.getNullCount(), + "Only the invalid enum child should be null"); + assertEquals(0, hostColors.getInt(0), + "The first repeated child should keep its valid enum"); + assertTrue(hostColors.isNull(1), + "The invalid repeated child enum should decode as null"); + assertEquals(1, hostColors.getInt(2), + "The second row should keep its valid enum"); + assertEquals(3, countView.getRowCount(), + "Sibling fields should remain visible for every repeated element"); assertEquals(0, countView.getNullCount(), - "Direct grandchild view should also be sanitized"); - assertEquals(30, hostCounts.getInt(0), - "The remaining direct child value should come from the valid list row only"); + "Sibling scalar fields should stay non-null when only the enum is invalid"); + assertEquals(10, hostCounts.getInt(0)); + assertEquals(20, hostCounts.getInt(1)); + assertEquals(30, hostCounts.getInt(2)); } } @@ -2847,7 +2848,7 @@ void testEnumMissingFieldDoesNotNullRow() { } @Test - void 
testNestedEnumInvalidNullsGrandchildFieldInPermissiveMode() { + void testNestedEnumInvalidKeepsRowAndSiblingFieldsInPermissiveMode() { // message WithNestedEnum { // optional int32 id = 1; // optional Detail detail = 2; @@ -2907,16 +2908,82 @@ void testNestedEnumInvalidNullsGrandchildFieldInPermissiveMode() { enumNames, false); ColumnVector detailCol = actual.getChildColumnView(1).copyToColumnVector(); + ColumnVector statusCol = detailCol.getChildColumnView(0).copyToColumnVector(); ColumnVector countCol = detailCol.getChildColumnView(1).copyToColumnVector(); HostColumnVector hostStruct = actual.copyToHost(); HostColumnVector hostDetail = detailCol.copyToHost(); + HostColumnVector hostStatus = statusCol.copyToHost(); HostColumnVector hostCount = countCol.copyToHost()) { - assertEquals(1, actual.getNullCount(), "Top-level row should be null"); - assertTrue(hostStruct.isNull(0), "Top-level struct should be null"); - assertEquals(1, detailCol.getNullCount(), "Nested struct child should be null after mask pushdown"); - assertTrue(hostDetail.isNull(0), "Nested struct child row should be null"); - assertEquals(1, countCol.getNullCount(), "Grandchild field should also be null"); - assertTrue(hostCount.isNull(0), "detail.count should be null when parent row is null"); + assertEquals(0, actual.getNullCount(), "Invalid nested enum should not null the top-level row"); + assertFalse(hostStruct.isNull(0), "Top-level struct should remain valid"); + assertEquals(0, detailCol.getNullCount(), "Nested struct should remain present"); + assertFalse(hostDetail.isNull(0), "Nested struct row should remain valid"); + assertEquals(1, statusCol.getNullCount(), "Only the invalid enum field should be null"); + assertTrue(hostStatus.isNull(0), "detail.status should decode as null"); + assertEquals(0, countCol.getNullCount(), "Sibling nested fields should remain visible"); + assertFalse(hostCount.isNull(0), "detail.count should remain valid"); + assertEquals(20, hostCount.getInt(0), 
"detail.count should preserve the decoded value"); + } + } + + @Test + void testMalformedNestedEnumPermissiveNullsWholeRow() { + // message WithNestedEnum { + // optional int32 id = 1; + // optional Detail detail = 2; + // optional string name = 3; + // } + // message Detail { + // enum Status { UNKNOWN = 0; OK = 1; BAD = 2; } + // optional Status status = 1; + // optional int32 count = 2; + // } + // + // The nested message length is intentionally truncated to 4 bytes. Spark CPU treats this as a + // malformed row in PERMISSIVE mode and returns a null struct row rather than partial data. + Byte[] rowValid = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), + box(tag(2, WT_LEN)), box(encodeVarint(4)), + box(tag(1, WT_VARINT)), box(encodeVarint(1)), + box(tag(2, WT_VARINT)), box(encodeVarint(10)), + box(tag(3, WT_LEN)), box(encodeVarint(2)), box("ok".getBytes())); + Byte[] rowInvalid = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(2)), + box(tag(2, WT_LEN)), box(encodeVarint(4)), + box(tag(1, WT_VARINT)), box(encodeVarint(999)), + box(tag(2, WT_VARINT)), box(encodeVarint(20)), + box(tag(3, WT_LEN)), box(encodeVarint(3)), box("bad".getBytes())); + + try (Table input = new Table.TestBuilder().column(rowValid, rowInvalid).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 2, 3, 1, 2}, + new int[]{-1, -1, -1, 1, 1}, + new int[]{0, 0, 0, 1, 1}, + new int[]{WT_VARINT, WT_LEN, WT_LEN, WT_VARINT, WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.STRUCT.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false, false, false}, + new boolean[]{false, false, false, false, false}, + new boolean[]{false, false, false, false, false}, + new long[]{0, 0, 0, 0, 0}, + 
new double[]{0.0, 0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false, false}, + new byte[][]{null, null, null, null, null}, + new int[][]{null, null, null, new int[]{0, 1, 2}, null}, + false); + HostColumnVector hostStruct = actual.copyToHost()) { + assertEquals(1, actual.getNullCount(), "Only the malformed row should be null"); + assertFalse(hostStruct.isNull(0), "The valid row should remain decoded"); + assertTrue(hostStruct.isNull(1), "The malformed nested row should be null in PERMISSIVE mode"); } } @@ -3049,6 +3116,56 @@ void testRepeatedMessageChildEnumAsString() { } } + @Test + void testRepeatedMessageChildEnumAsStringInvalidKeepsRowValid() { + Byte[] item0 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(1))); // FOO + Byte[] item1 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(999))); // invalid + Byte[] row = concat( + box(tag(1, WT_LEN)), box(encodeVarint(item0.length)), item0, + box(tag(1, WT_LEN)), box(encodeVarint(item1.length)), item1); + + byte[][][] enumNames = new byte[][][] { + null, + new byte[][] { + "UNKNOWN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "FOO".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BAR".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_ENUM_STRING}, + new boolean[]{true, false}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, new int[]{0, 1, 2}}, + enumNames, + false); + ColumnVector items = actual.getChildColumnView(0).copyToColumnVector(); + ColumnVector itemStructs = 
items.getChildColumnView(0).copyToColumnVector(); + ColumnVector priorities = itemStructs.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostStruct = actual.copyToHost(); + HostColumnVector hostPriorities = priorities.copyToHost()) { + assertEquals(0, actual.getNullCount(), "Invalid child enum should not null the top-level row"); + assertFalse(hostStruct.isNull(0), "The top-level row should remain valid"); + assertEquals(2, priorities.getRowCount(), "Both repeated message elements should remain visible"); + assertEquals(1, priorities.getNullCount(), "Only the invalid enum field should be null"); + assertEquals("FOO", hostPriorities.getJavaString(0)); + assertTrue(hostPriorities.isNull(1), "The invalid repeated child enum should decode as null"); + } + } + @Test void testNestedRepeatedEnumAsString() { // message Inner { repeated Priority priority = 1; } From 5ef84e9c1b81e0d23b933457bc0458a5b6c98399 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 16 Mar 2026 11:53:59 +0800 Subject: [PATCH 089/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 27 +++++------- src/main/cpp/src/protobuf.hpp | 6 +++ src/main/cpp/src/protobuf_builders.cu | 44 +++++++++---------- src/main/cpp/src/protobuf_common.cuh | 14 +++--- .../rapids/jni/ProtobufSchemaDescriptor.java | 5 +++ .../jni/ProtobufSchemaDescriptorTest.java | 42 ++++++++++++++++++ 6 files changed, 92 insertions(+), 46 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index b4dcb3f313..279a2e54ea 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -267,11 +267,10 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& bool has_enum_fields = std::any_of( enum_valid_values.begin(), enum_valid_values.end(), [](auto const& v) { return !v.empty(); }); bool track_permissive_null_rows = !fail_on_errors; - rmm::device_uvector d_row_has_invalid_enum( - track_permissive_null_rows ? 
num_rows : 0, stream, mr); + rmm::device_uvector d_row_force_null(track_permissive_null_rows ? num_rows : 0, stream, mr); if (track_permissive_null_rows) { CUDF_CUDA_TRY( - cudaMemsetAsync(d_row_has_invalid_enum.data(), 0, num_rows * sizeof(bool), stream.value())); + cudaMemsetAsync(d_row_force_null.data(), 0, num_rows * sizeof(bool), stream.value())); } auto const threads = THREADS_PER_BLOCK; @@ -387,7 +386,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& static_cast(h_field_lookup.size()), d_locations.data(), d_error.data(), - track_permissive_null_rows ? d_row_has_invalid_enum.data() : nullptr); + track_permissive_null_rows ? d_row_force_null.data() : nullptr); // Required-field validation applies to all scalar leaves, not just top-level numerics. maybe_check_required_fields(d_locations.data(), @@ -589,7 +588,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& schema_idx, enum_valid_values, enum_names, - d_row_has_invalid_enum, + d_row_force_null, d_error, stream, mr); @@ -632,14 +631,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& auto const& valid_enums = enum_valid_values[schema_idx]; auto const& enum_name_bytes = enum_names[schema_idx]; if (!valid_enums.empty() && valid_enums.size() == enum_name_bytes.size()) { - column_map[schema_idx] = build_enum_string_column(out, - valid, - valid_enums, - enum_name_bytes, - d_row_has_invalid_enum, - num_rows, - stream, - mr); + column_map[schema_idx] = build_enum_string_column( + out, valid, valid_enums, enum_name_bytes, d_row_force_null, num_rows, stream, mr); } else { // Missing enum metadata for enum-as-string field; mark as decode error. 
set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); @@ -968,7 +961,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& num_rows, field_meta.enum_valid_values, field_meta.enum_names, - d_row_has_invalid_enum, + d_row_force_null, d_error, stream, mr); @@ -1038,7 +1031,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& schema, enum_valid_values, enum_names, - d_row_has_invalid_enum, + d_row_force_null, d_error, stream, mr); @@ -1123,7 +1116,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& default_strings, enum_valid_values, enum_names, - d_row_has_invalid_enum, + d_row_force_null, d_error, num_rows, stream, @@ -1175,7 +1168,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& auto [mask, null_count] = cudf::detail::valid_if( thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), - [row_invalid = d_row_has_invalid_enum.data()] __device__(cudf::size_type row) { + [row_invalid = d_row_force_null.data()] __device__(cudf::size_type row) { return !row_invalid[row]; }, stream, diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index 666cc5c845..f6cf5e5ea6 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -204,6 +204,12 @@ inline void validate_decode_context(ProtobufDecodeContext const& context) "protobuf decode context: repeated field cannot carry default value at field " + std::to_string(i)); } + if (field.has_default_value && + (type.id() == cudf::type_id::STRUCT || type.id() == cudf::type_id::LIST)) { + throw std::invalid_argument( + "protobuf decode context: STRUCT/LIST field cannot carry default value at field " + + std::to_string(i)); + } if (!is_encoding_compatible(field, type)) { throw std::invalid_argument( "protobuf decode context: incompatible wire type/encoding/output type at field " + diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 
04d9b93a2b..75f28d6f2a 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -343,7 +343,7 @@ std::unique_ptr build_enum_string_column( rmm::device_uvector& valid, std::vector const& valid_enums, std::vector> const& enum_name_bytes, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, int num_rows, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, @@ -367,7 +367,7 @@ std::unique_ptr build_enum_string_column( static_cast(valid_enums.size()), num_rows); propagate_invalid_enum_flags_to_rows(d_item_has_invalid_enum, - d_row_has_invalid_enum, + d_row_force_null, num_rows, top_row_indices, propagate_invalid_rows, @@ -386,7 +386,7 @@ inline std::unique_ptr build_repeated_msg_child_enum_string_column int total_count, std::vector const& valid_enums, std::vector> const& enum_name_bytes, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, int32_t const* top_row_indices, bool propagate_invalid_rows, rmm::device_uvector& d_error, @@ -428,7 +428,7 @@ inline std::unique_ptr build_repeated_msg_child_enum_string_column static_cast(valid_enums.size()), total_count); propagate_invalid_enum_flags_to_rows(d_elem_has_invalid_enum, - d_row_has_invalid_enum, + d_row_force_null, total_count, top_row_indices, propagate_invalid_rows, @@ -448,7 +448,7 @@ std::unique_ptr build_repeated_enum_string_column( int num_rows, std::vector const& valid_enums, std::vector> const& enum_name_bytes, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) @@ -494,7 +494,7 @@ std::unique_ptr build_repeated_enum_string_column( d_top_row_indices.begin(), [] __device__(repeated_occurrence const& occ) { return occ.row_idx; }); propagate_invalid_enum_flags_to_rows(d_elem_has_invalid_enum, - d_row_has_invalid_enum, + d_row_force_null, total_count, 
d_top_row_indices.data(), true, @@ -657,7 +657,7 @@ std::unique_ptr build_nested_struct_column( std::vector> const& default_strings, std::vector> const& enum_valid_values, std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, int num_rows, rmm::cuda_stream_view stream, @@ -686,7 +686,7 @@ std::unique_ptr build_repeated_child_list_column( std::vector> const& default_strings, std::vector> const& enum_valid_values, std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, @@ -715,7 +715,7 @@ std::unique_ptr build_repeated_struct_column( std::vector const& schema, std::vector> const& enum_valid_values, std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error_top, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) @@ -889,7 +889,7 @@ std::unique_ptr build_repeated_struct_column( default_strings, enum_valid_values, enum_names, - d_row_has_invalid_enum, + d_row_force_null, d_error_top, stream, mr, @@ -929,7 +929,7 @@ std::unique_ptr build_repeated_struct_column( child_schema_idx, enum_valid_values, enum_names, - d_row_has_invalid_enum, + d_row_force_null, d_error, stream, mr, @@ -953,7 +953,7 @@ std::unique_ptr build_repeated_struct_column( total_count, enum_valid_values[child_schema_idx], enum_names[child_schema_idx], - d_row_has_invalid_enum, + d_row_force_null, d_top_row_indices.data(), false, d_error, @@ -1036,7 +1036,7 @@ std::unique_ptr build_repeated_struct_column( default_strings, enum_valid_values, enum_names, - d_row_has_invalid_enum, + d_row_force_null, d_error_top, total_count, stream, @@ -1095,7 +1095,7 @@ std::unique_ptr build_nested_struct_column( std::vector> const& default_strings, std::vector> 
const& enum_valid_values, std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, int num_rows, rmm::cuda_stream_view stream, @@ -1197,7 +1197,7 @@ std::unique_ptr build_nested_struct_column( default_strings, enum_valid_values, enum_names, - d_row_has_invalid_enum, + d_row_force_null, d_error, stream, mr, @@ -1237,7 +1237,7 @@ std::unique_ptr build_nested_struct_column( child_schema_idx, enum_valid_values, enum_names, - d_row_has_invalid_enum, + d_row_force_null, d_error, stream, mr, @@ -1275,7 +1275,7 @@ std::unique_ptr build_nested_struct_column( valid, valid_enums, enum_name_bytes, - d_row_has_invalid_enum, + d_row_force_null, num_rows, stream, mr, @@ -1400,7 +1400,7 @@ std::unique_ptr build_nested_struct_column( default_strings, enum_valid_values, enum_names, - d_row_has_invalid_enum, + d_row_force_null, d_error, num_rows, stream, @@ -1447,7 +1447,7 @@ std::unique_ptr build_repeated_child_list_column( std::vector> const& default_strings, std::vector> const& enum_valid_values, std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, @@ -1582,7 +1582,7 @@ std::unique_ptr build_repeated_child_list_column( child_schema_idx, enum_valid_values, enum_names, - d_row_has_invalid_enum, + d_row_force_null, d_error, stream, mr, @@ -1622,7 +1622,7 @@ std::unique_ptr build_repeated_child_list_column( static_cast(lookup.d_valid_enums.size()), total_rep_count); propagate_invalid_enum_flags_to_rows(d_elem_has_invalid_enum, - d_row_has_invalid_enum, + d_row_force_null, total_rep_count, d_rep_top_row_indices.data(), propagate_invalid_rows, @@ -1690,7 +1690,7 @@ std::unique_ptr build_repeated_child_list_column( default_strings, enum_valid_values, enum_names, - d_row_has_invalid_enum, + d_row_force_null, d_error, total_rep_count, 
stream, diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index c090fbc02a..e715f37054 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -1356,7 +1356,7 @@ std::unique_ptr build_enum_string_column( rmm::device_uvector& valid, std::vector const& valid_enums, std::vector> const& enum_name_bytes, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, int num_rows, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, @@ -1375,7 +1375,7 @@ std::unique_ptr build_repeated_enum_string_column( int num_rows, std::vector const& valid_enums, std::vector> const& enum_name_bytes, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr); @@ -1411,7 +1411,7 @@ std::unique_ptr build_nested_struct_column( std::vector> const& default_strings, std::vector> const& enum_valid_values, std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, int num_rows, rmm::cuda_stream_view stream, @@ -1437,7 +1437,7 @@ std::unique_ptr build_repeated_child_list_column( std::vector> const& default_strings, std::vector> const& enum_valid_values, std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, @@ -1466,7 +1466,7 @@ std::unique_ptr build_repeated_struct_column( std::vector const& schema, std::vector> const& enum_valid_values, std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error_top, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr); @@ -1554,7 +1554,7 @@ inline std::unique_ptr 
extract_typed_column( int schema_idx, std::vector> const& enum_valid_values, std::vector>> const& enum_names, - rmm::device_uvector& d_row_has_invalid_enum, + rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr, @@ -1603,7 +1603,7 @@ inline std::unique_ptr extract_typed_column( validate_enum_and_propagate_rows(out, valid, valid_enums, - d_row_has_invalid_enum, + d_row_force_null, num_items, top_row_indices, propagate_invalid_rows, diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index 7b8dc838e8..32efc9105e 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -243,6 +243,11 @@ private static void validate( throw new IllegalArgumentException( "Repeated field at index " + i + " cannot carry a default value"); } + if (hasDefaultValue[i] && + (outputTypeIds[i] == STRUCT_TYPE_ID || outputTypeIds[i] == LIST_TYPE_ID)) { + throw new IllegalArgumentException( + "STRUCT/LIST field at index " + i + " cannot carry a default value"); + } if (enc == Protobuf.ENC_ENUM_STRING && (enumValidValues[i] == null || enumNames[i] == null)) { throw new IllegalArgumentException( diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java index 3c1ae150eb..21a20109bb 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java @@ -81,6 +81,48 @@ void testFieldCannotBeBothRepeatedAndRequired() { new byte[][][]{null})); } + @Test + void testStructFieldCannotCarryDefaultValue() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1}, + 
new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_LEN}, + new int[]{ai.rapids.cudf.DType.STRUCT.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + new byte[][][]{null})); + } + + @Test + void testListFieldCannotCarryDefaultValue() { + assertThrows(IllegalArgumentException.class, () -> + new ProtobufSchemaDescriptor( + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_LEN}, + new int[]{ai.rapids.cudf.DType.LIST.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + new byte[][][]{null})); + } + @Test void testEnumStringRequiresEnumMetadata() { assertThrows(IllegalArgumentException.class, () -> From 4817d06813fd8b1785970a30726ab6cbe7264ed5 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 16 Mar 2026 13:23:10 +0800 Subject: [PATCH 090/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 5 +++++ src/main/cpp/src/protobuf.cu | 20 ++++++++------------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index eac0cce59f..a77fced341 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -107,6 +107,11 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, auto const depth = n_depth_levels[i]; auto const wire_type = n_wire_types[i]; + if (n_field_numbers[i] <= 0) { + JNI_THROW_NEW( + env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, "field_numbers must be positive", 0); + } + if (!(wire_type == 0 || wire_type == 1 || wire_type == 2 || wire_type == 5)) { JNI_THROW_NEW( env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, 
"wire_types must be one of {0,1,2,5}", 0); diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 279a2e54ea..59085276dd 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -73,23 +73,16 @@ void propagate_list_nulls_to_descendants(cudf::column& list_col, CUDF_EXPECTS(list_view.offset() == 0, "decoder list null propagation expects unsliced list columns"); auto const* offsets_begin = list_view.offsets_begin(); + auto const* offsets_end = list_view.offsets_end(); // LIST children are not row-aligned with their parent. Expand the list-row null mask across // every covered child element so direct access to the backing child column also observes nulls. auto [element_mask, element_null_count] = cudf::detail::valid_if( thrust::make_counting_iterator(0), thrust::make_counting_iterator(child_size), - [list_mask_ptr, offsets_begin, num_rows] __device__(cudf::size_type idx) { - cudf::size_type lo = 0; - cudf::size_type hi = num_rows; - while (lo < hi) { - auto const mid = lo + (hi - lo) / 2; - if (offsets_begin[mid + 1] <= idx) { - lo = mid + 1; - } else { - hi = mid; - } - } - return list_mask_ptr == nullptr || cudf::bit_is_set(list_mask_ptr, lo); + [list_mask_ptr, offsets_begin, offsets_end] __device__(cudf::size_type idx) { + auto const it = thrust::upper_bound(thrust::seq, offsets_begin, offsets_end, idx); + auto const row = static_cast(it - offsets_begin) - 1; + return list_mask_ptr == nullptr || cudf::bit_is_set(list_mask_ptr, row); }, stream, mr); @@ -1158,6 +1151,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& CUDF_CUDA_TRY( cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); stream.synchronize(); + if (h_error == ERR_SCHEMA_TOO_LARGE || h_error == ERR_REPEATED_COUNT_MISMATCH) { + throw cudf::logic_error(error_message(h_error)); + } if (fail_on_errors && h_error != 0) { throw cudf::logic_error(error_message(h_error)); } // Build final struct with 
PERMISSIVE mode null mask for invalid enums From bdd2e40c2c992dd287c3201d53efc2e29b6c5a24 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 16 Mar 2026 15:27:07 +0800 Subject: [PATCH 091/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf.cu | 4 ++ src/main/cpp/src/protobuf.hpp | 15 ++++-- src/main/cpp/src/protobuf_builders.cu | 4 ++ src/main/cpp/src/protobuf_common.cuh | 6 +++ src/main/cpp/src/protobuf_kernels.cu | 7 +++ .../rapids/jni/ProtobufSchemaDescriptor.java | 5 +- .../jni/ProtobufSchemaDescriptorTest.java | 6 +++ .../nvidia/spark/rapids/jni/ProtobufTest.java | 46 +++++++++++++++++-- 8 files changed, 82 insertions(+), 11 deletions(-) diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 59085276dd..7892e4fdee 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -389,6 +389,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input.null_count() > 0 ? binary_input.null_mask() : nullptr, binary_input.offset(), nullptr, + track_permissive_null_rows ? d_row_force_null.data() : nullptr, + nullptr, d_error.data(), stream, mr); @@ -711,6 +713,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input.null_count() > 0 ? binary_input.null_mask() : nullptr, binary_input.offset(), nullptr, + track_permissive_null_rows ? 
d_row_force_null.data() : nullptr, + nullptr, d_error.data(), stream, mr); diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index f6cf5e5ea6..c574c7cfaa 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -215,10 +215,17 @@ inline void validate_decode_context(ProtobufDecodeContext const& context) "protobuf decode context: incompatible wire type/encoding/output type at field " + std::to_string(i)); } - if (field.encoding == ENC_ENUM_STRING && - context.enum_valid_values[i].size() != context.enum_names[i].size()) { - throw std::invalid_argument( - "protobuf decode context: enum-as-string metadata mismatch at field " + std::to_string(i)); + if (field.encoding == ENC_ENUM_STRING) { + if (context.enum_valid_values[i].empty() || context.enum_names[i].empty()) { + throw std::invalid_argument( + "protobuf decode context: enum-as-string field requires non-empty metadata at field " + + std::to_string(i)); + } + if (context.enum_valid_values[i].size() != context.enum_names[i].size()) { + throw std::invalid_argument( + "protobuf decode context: enum-as-string metadata mismatch at field " + + std::to_string(i)); + } } } } diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 75f28d6f2a..113502a7d2 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -853,6 +853,8 @@ std::unique_ptr build_repeated_struct_column( nullptr, 0, nullptr, + d_row_force_null.size() > 0 ? d_row_force_null.data() : nullptr, + d_top_row_indices.data(), d_error.data(), stream, mr); @@ -1168,6 +1170,8 @@ std::unique_ptr build_nested_struct_column( nullptr, 0, d_parent_locs.data(), + d_row_force_null.size() > 0 ? 
d_row_force_null.data() : nullptr, + top_row_indices, d_error.data(), stream, mr); diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index e715f37054..2866cc1c08 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -1190,6 +1190,8 @@ __global__ void check_required_fields_kernel(field_location const* locations, cudf::bitmask_type const* input_null_mask, cudf::size_type input_offset, field_location const* parent_locs, + bool* row_force_null, + int32_t const* top_row_indices, int* error_flag); inline void maybe_check_required_fields(field_location const* locations, @@ -1199,6 +1201,8 @@ inline void maybe_check_required_fields(field_location const* locations, cudf::bitmask_type const* input_null_mask, cudf::size_type input_offset, field_location const* parent_locs, + bool* row_force_null, + int32_t const* top_row_indices, int* error_flag, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) @@ -1229,6 +1233,8 @@ inline void maybe_check_required_fields(field_location const* locations, input_null_mask, input_offset, parent_locs, + row_force_null, + top_row_indices, error_flag); } diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index 3bae207808..fa3cc54999 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -1138,6 +1138,8 @@ __global__ void check_required_fields_kernel( cudf::bitmask_type const* input_null_mask, // optional top-level input null mask cudf::size_type input_offset, // bit offset for sliced top-level input field_location const* parent_locs, // [num_rows] optional parent presence for nested rows + bool* row_force_null, // [top_level_num_rows] optional permissive row nulling + int32_t const* top_row_indices, // [num_rows] optional nested-row -> top-row mapping int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); @@ -1152,6 +1154,11 @@ __global__ void 
check_required_fields_kernel( static_cast(num_fields), static_cast(f))] .offset < 0) { + if (row_force_null != nullptr) { + auto const top_row = + top_row_indices != nullptr ? top_row_indices[row] : static_cast(row); + row_force_null[top_row] = true; + } // Required field is missing - set error flag set_error_once(error_flag, ERR_REQUIRED); return; // No need to check other fields for this row diff --git a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java index 32efc9105e..810cfb60a9 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java @@ -249,10 +249,11 @@ private static void validate( "STRUCT/LIST field at index " + i + " cannot carry a default value"); } if (enc == Protobuf.ENC_ENUM_STRING && - (enumValidValues[i] == null || enumNames[i] == null)) { + (enumValidValues[i] == null || enumValidValues[i].length == 0 || + enumNames[i] == null || enumNames[i].length == 0)) { throw new IllegalArgumentException( "Enum-as-string field at index " + i + - " must provide both enumValidValues and enumNames"); + " must provide non-empty enumValidValues and enumNames"); } if (enumValidValues[i] != null) { int[] ev = enumValidValues[i]; diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java index 21a20109bb..bf56c2662b 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java @@ -134,6 +134,12 @@ void testEnumStringRequiresEnumMetadata() { new byte[][]{"A".getBytes(), "B".getBytes()})); } + @Test + void testEnumStringRejectsEmptyEnumArrays() { + assertThrows(IllegalArgumentException.class, () -> + makeDescriptor(false, false, Protobuf.ENC_ENUM_STRING, new 
int[]{}, new byte[][]{})); + } + @Test void testDuplicateFieldNumbersUnderSameParentRejected() { assertThrows(IllegalArgumentException.class, () -> diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 2528ff7a0f..b013a65bee 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -1143,16 +1143,13 @@ void testRequiredFieldPresent() { @Test void testRequiredFieldMissing_Permissive() { - // Required field missing in permissive mode - should return null without exception + // Required field missing in permissive mode - should null the whole row without exception // message Msg { required int64 id = 1; optional string name = 2; } // Only name field present, required id is missing Byte[] row = concat( box(tag(2, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes())); try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector expectedId = ColumnVector.fromBoxedLongs((Long) null); - ColumnVector expectedName = ColumnVector.fromStrings("hello"); - ColumnVector expectedStruct = ColumnVector.makeStruct(expectedId, expectedName); ColumnVector actualStruct = decodeAllFieldsWithRequired( input.getColumn(0), new int[]{1, 2}, @@ -1160,7 +1157,8 @@ void testRequiredFieldMissing_Permissive() { new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, new boolean[]{true, false}, // id is required, name is optional false)) { // permissive mode - don't fail on errors - AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + assertSingleNullStructRow(actualStruct, + "Missing top-level required field should null the row in PERMISSIVE mode"); } } @@ -1369,6 +1367,44 @@ void testRequiredFieldInsideNestedMessageMissing_Failfast() { } } + @Test + void testRequiredFieldInsideNestedMessageMissing_Permissive() { + // message Outer { optional Inner detail = 1; optional string 
name = 2; } + // message Inner { required int32 id = 1; optional string note = 2; } + // If detail is present but nested required id is missing, PERMISSIVE should null the row. + Byte[] inner = concat( + box(tag(2, WT_LEN)), box(encodeVarint(4)), box("oops".getBytes())); + Byte[] row = concat( + box(tag(1, WT_LEN)), box(encodeVarint(inner.length)), inner, + box(tag(2, WT_LEN)), box(encodeVarint(7)), box("outside".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 1, 2, 2}, + new int[]{-1, 0, 0, -1}, + new int[]{0, 1, 1, 0}, + new int[]{WT_LEN, WT_VARINT, WT_LEN, WT_LEN}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false, false}, + new boolean[]{false, true, false, false}, + new boolean[]{false, false, false, false}, + new long[]{0, 0, 0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false}, + new byte[][]{null, null, null, null}, + new int[][]{null, null, null, null}, + false)) { + assertSingleNullStructRow(actual, + "Missing nested required field should null the outer row in PERMISSIVE mode"); + } + } + @Test void testAbsentNestedParentSkipsRequiredChildCheck_Failfast() { // message Outer { optional Inner detail = 1; } From 6101a4b0f12bb3e353a7d5cbce643ab349e38ef0 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 17 Mar 2026 11:08:05 +0800 Subject: [PATCH 092/107] address human comments Signed-off-by: Haoyang Li --- src/main/cpp/benchmarks/protobuf_decode.cu | 79 ++++++++++----- src/main/cpp/src/ProtobufJni.cpp | 15 ++- src/main/cpp/src/protobuf.cu | 27 ++++-- src/main/cpp/src/protobuf.hpp | 90 ++++++++++++----- src/main/cpp/src/protobuf_builders.cu | 9 
+- src/main/cpp/src/protobuf_common.cuh | 107 +++++++++++++-------- src/main/cpp/src/protobuf_kernels.cu | 48 +++++---- 7 files changed, 254 insertions(+), 121 deletions(-) diff --git a/src/main/cpp/benchmarks/protobuf_decode.cu b/src/main/cpp/benchmarks/protobuf_decode.cu index 511e47d9cd..b54263a759 100644 --- a/src/main/cpp/benchmarks/protobuf_decode.cu +++ b/src/main/cpp/benchmarks/protobuf_decode.cu @@ -55,13 +55,17 @@ void encode_tag(std::vector& buf, int field_number, int wire_type) void encode_varint_field(std::vector& buf, int field_number, int64_t value) { - encode_tag(buf, field_number, /*WT_VARINT=*/0); + encode_tag(buf, + field_number, + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT)); encode_varint(buf, static_cast(value)); } void encode_fixed32_field(std::vector& buf, int field_number, float value) { - encode_tag(buf, field_number, /*WT_32BIT=*/5); + encode_tag(buf, + field_number, + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)); uint32_t bits; std::memcpy(&bits, &value, sizeof(bits)); for (int i = 0; i < 4; i++) { @@ -72,7 +76,9 @@ void encode_fixed32_field(std::vector& buf, int field_number, float val void encode_fixed64_field(std::vector& buf, int field_number, double value) { - encode_tag(buf, field_number, /*WT_64BIT=*/1); + encode_tag(buf, + field_number, + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT)); uint64_t bits; std::memcpy(&bits, &value, sizeof(bits)); for (int i = 0; i < 8; i++) { @@ -83,7 +89,8 @@ void encode_fixed64_field(std::vector& buf, int field_number, double va void encode_len_field(std::vector& buf, int field_number, void const* data, size_t len) { - encode_tag(buf, field_number, /*WT_LEN=*/2); + encode_tag( + buf, field_number, spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)); encode_varint(buf, len); auto const* p = static_cast(data); buf.insert(buf.end(), p, p + len); @@ -94,7 +101,7 @@ void 
encode_string_field(std::vector& buf, int field_number, std::strin encode_len_field(buf, field_number, s.data(), s.size()); } -// Encode a nested message: write its content into a temporary buffer, then emit as WT_LEN. +// Encode a nested message: write its content into a temporary buffer, then emit as LEN. template void encode_nested_message(std::vector& buf, int field_number, Fn&& content_fn) { @@ -178,7 +185,9 @@ void encode_string_field_record(std::vector& buf, std::vector& out_occurrences, int32_t row_idx) { - encode_tag(buf, field_number, /*WT_LEN=*/2); + encode_tag(buf, + field_number, + /*spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)=*/2); encode_varint(buf, s.size()); auto const data_offset = checked_size_to_i32(buf.size(), "string field data offset"); buf.insert(buf.end(), s.begin(), s.end()); @@ -209,21 +218,36 @@ struct FlatScalarCase { cudf::type_id::FLOAT32, cudf::type_id::FLOAT64, cudf::type_id::BOOL8}; - int wt_for_type[] = {0 /*WT_VARINT*/, 0, 5 /*WT_32BIT*/, 1 /*WT_64BIT*/, 0}; + int wt_for_type[] = { + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT)}; int fn = 1; for (int i = 0; i < num_int_fields; i++, fn++) { int ti = i % 5; auto ty = int_types[ti]; int wt = wt_for_type[ti]; - int enc = spark_rapids_jni::ENC_DEFAULT; - if (ty == cudf::type_id::FLOAT32) enc = spark_rapids_jni::ENC_FIXED; - if (ty == cudf::type_id::FLOAT64) enc = spark_rapids_jni::ENC_FIXED; + int enc = spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT); + if (ty == cudf::type_id::FLOAT32) + enc = spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED); + if (ty == 
cudf::type_id::FLOAT64) + enc = spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED); ctx.schema.push_back({fn, -1, 0, wt, ty, enc, false, false, false}); } for (int i = 0; i < num_string_fields; i++, fn++) { ctx.schema.push_back( - {fn, -1, 0, 2 /*WT_LEN*/, cudf::type_id::STRING, 0, false, false, false}); + {fn, + -1, + 0, + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN), + cudf::type_id::STRING, + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT), + false, + false, + false}); } size_t n = ctx.schema.size(); @@ -500,13 +524,18 @@ struct WideRepeatedMessageCase { cudf::type_id::FLOAT64, cudf::type_id::BOOL8, cudf::type_id::STRING}; - int child_wt[] = {0, 0, 5, 1, 0, 2}; - int child_enc[] = {spark_rapids_jni::ENC_DEFAULT, - spark_rapids_jni::ENC_DEFAULT, - spark_rapids_jni::ENC_FIXED, - spark_rapids_jni::ENC_FIXED, - spark_rapids_jni::ENC_DEFAULT, - spark_rapids_jni::ENC_DEFAULT}; + int child_wt[] = {spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)}; + int child_enc[] = {spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT), + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT), + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED), + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED), + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT), + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT)}; // Keep strings sparse so the case remains 
dominated by wide child scanning // rather than varlen copy traffic. @@ -1046,12 +1075,14 @@ static void BM_protobuf_repeated_child_string_count_scan(nvbench::state& state) std::vector h_schema( num_repeated_children); for (int i = 0; i < num_repeated_children; i++) { - h_schema[i].field_number = i + 1; - h_schema[i].parent_idx = -1; - h_schema[i].depth = 0; - h_schema[i].wire_type = spark_rapids_jni::protobuf_detail::WT_LEN; - h_schema[i].output_type_id = static_cast(cudf::type_id::STRING); - h_schema[i].encoding = 0; + h_schema[i].field_number = i + 1; + h_schema[i].parent_idx = -1; + h_schema[i].depth = 0; + h_schema[i].wire_type = + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN); + h_schema[i].output_type_id = static_cast(cudf::type_id::STRING); + h_schema[i].encoding = + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT); h_schema[i].is_repeated = true; h_schema[i].is_required = false; h_schema[i].has_default_value = false; diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index a77fced341..d1fc753a62 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -112,9 +112,18 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, "field_numbers must be positive", 0); } - if (!(wire_type == 0 || wire_type == 1 || wire_type == 2 || wire_type == 5)) { - JNI_THROW_NEW( - env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, "wire_types must be one of {0,1,2,5}", 0); + if (!(wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT) || + wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT) || + wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN) || + wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT))) { + JNI_THROW_NEW(env, + 
cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, + "wire_types must be one of {VARINT,I64BIT,LEN,I32BIT}", + 0); } if (parent_idx < -1 || parent_idx >= num_fields || parent_idx >= i) { diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf.cu index 7892e4fdee..5a7ce0540c 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf.cu @@ -35,6 +35,7 @@ void apply_parent_mask_to_row_aligned_column(cudf::column& col, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { + if (parent_null_count == 0) { return; } auto child_view = col.mutable_view(); CUDF_EXPECTS(child_view.size() == num_rows, "struct child size must match parent row count for null propagation"); @@ -52,6 +53,8 @@ void apply_parent_mask_to_row_aligned_column(cudf::column& col, stream); col.set_null_count(child_view.size() - valid_count); } else { + CUDF_EXPECTS(child_view.offset() == 0, + "non-nullable child with nonzero offset not supported for null propagation"); auto child_mask = cudf::detail::copy_bitmask(parent_mask_ptr, 0, num_rows, stream, mr); col.set_null_mask(std::move(child_mask), parent_null_count); } @@ -418,12 +421,14 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& int si = scalar_field_indices[i]; auto tid = schema_output_types[si].id(); int enc = schema[si].encoding; - bool zz = (enc == spark_rapids_jni::ENC_ZIGZAG); + bool zz = + (enc == spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ZIGZAG)); // STRING, LIST, and enum-as-string go to per-field path if (tid == cudf::type_id::STRING || tid == cudf::type_id::LIST) continue; - bool is_fixed = (enc == spark_rapids_jni::ENC_FIXED); + bool is_fixed = + (enc == spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED)); // INT32 with enum validation goes to fallback if (tid == cudf::type_id::INT32 && !zz && !is_fixed && @@ -551,10 +556,14 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& LAUNCH_VARINT_BATCH(4, uint8_t, false); 
LAUNCH_VARINT_BATCH(5, int32_t, true); LAUNCH_VARINT_BATCH(6, int64_t, true); - LAUNCH_FIXED_BATCH(7, float, WT_32BIT); - LAUNCH_FIXED_BATCH(8, double, WT_64BIT); - LAUNCH_FIXED_BATCH(9, int32_t, WT_32BIT); - LAUNCH_FIXED_BATCH(10, int64_t, WT_64BIT); + LAUNCH_FIXED_BATCH( + 7, float, spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)); + LAUNCH_FIXED_BATCH( + 8, double, spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT)); + LAUNCH_FIXED_BATCH( + 9, int32_t, spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)); + LAUNCH_FIXED_BATCH( + 10, int64_t, spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT)); #undef LAUNCH_VARINT_BATCH #undef LAUNCH_FIXED_BATCH @@ -601,7 +610,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& switch (dt.id()) { case cudf::type_id::STRING: { - if (enc == spark_rapids_jni::ENC_ENUM_STRING) { + if (enc == + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { // ENUM-as-string path: // 1. Decode enum numeric value as INT32 varint. // 2. Validate against enum_valid_values. 
@@ -944,7 +954,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case cudf::type_id::STRING: { auto const field_meta = make_field_meta_view(context, schema_idx); auto enc = field_meta.schema.encoding; - if (enc == spark_rapids_jni::ENC_ENUM_STRING) { + if (enc == + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { if (!field_meta.enum_valid_values.empty() && field_meta.enum_valid_values.size() == field_meta.enum_names.size()) { column_map[schema_idx] = diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp index c574c7cfaa..5a4844a58f 100644 --- a/src/main/cpp/src/protobuf.hpp +++ b/src/main/cpp/src/protobuf.hpp @@ -31,17 +31,31 @@ namespace spark_rapids_jni { // Encoding constants -constexpr int ENC_DEFAULT = 0; -constexpr int ENC_FIXED = 1; -constexpr int ENC_ZIGZAG = 2; -constexpr int ENC_ENUM_STRING = 3; +enum class proto_encoding : int { + DEFAULT = 0, + FIXED = 1, + ZIGZAG = 2, + ENUM_STRING = 3, +}; +CUDF_HOST_DEVICE constexpr int encoding_value(proto_encoding encoding) +{ + return static_cast(encoding); +} constexpr int MAX_FIELD_NUMBER = (1 << 29) - 1; // Wire type constants -constexpr int WT_VARINT = 0; -constexpr int WT_64BIT = 1; -constexpr int WT_LEN = 2; -constexpr int WT_32BIT = 5; +enum class proto_wire_type : int { + VARINT = 0, + I64BIT = 1, + LEN = 2, + SGROUP = 3, + EGROUP = 4, + I32BIT = 5, +}; +CUDF_HOST_DEVICE constexpr int wire_type_value(proto_wire_type wire_type) +{ + return static_cast(wire_type); +} // Maximum nesting depth for nested messages constexpr int MAX_NESTING_DEPTH = 10; @@ -54,9 +68,9 @@ struct nested_field_descriptor { int field_number; // Protobuf field number int parent_idx; // Index of parent field in schema (-1 for top-level) int depth; // Nesting depth (0 for top-level) - int wire_type; // Expected wire type + int wire_type; // Expected wire type (proto_wire_type) cudf::type_id output_type; // Output cudf type - int encoding; // Encoding type 
(ENC_DEFAULT, ENC_FIXED, ENC_ZIGZAG) + int encoding; // Encoding type (proto_encoding) bool is_repeated; // Whether this field is repeated (array) bool is_required; // Whether this field is required (proto2) bool has_default_value; // Whether this field has a default value @@ -92,34 +106,50 @@ inline bool is_encoding_compatible(nested_field_descriptor const& field, cudf::data_type const& type) { switch (field.encoding) { - case ENC_DEFAULT: + case spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT): switch (type.id()) { case cudf::type_id::BOOL8: case cudf::type_id::INT32: case cudf::type_id::UINT32: case cudf::type_id::INT64: - case cudf::type_id::UINT64: return field.wire_type == WT_VARINT; - case cudf::type_id::FLOAT32: return field.wire_type == WT_32BIT; - case cudf::type_id::FLOAT64: return field.wire_type == WT_64BIT; + case cudf::type_id::UINT64: + return field.wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT); + case cudf::type_id::FLOAT32: + return field.wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT); + case cudf::type_id::FLOAT64: + return field.wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT); case cudf::type_id::STRING: case cudf::type_id::LIST: - case cudf::type_id::STRUCT: return field.wire_type == WT_LEN; + case cudf::type_id::STRUCT: + return field.wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN); default: return false; } - case ENC_FIXED: + case spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED): switch (type.id()) { case cudf::type_id::INT32: case cudf::type_id::UINT32: - case cudf::type_id::FLOAT32: return field.wire_type == WT_32BIT; + case cudf::type_id::FLOAT32: + return field.wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT); case cudf::type_id::INT64: case cudf::type_id::UINT64: - case 
cudf::type_id::FLOAT64: return field.wire_type == WT_64BIT; + case cudf::type_id::FLOAT64: + return field.wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT); default: return false; } - case ENC_ZIGZAG: - return field.wire_type == WT_VARINT && + case spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ZIGZAG): + return field.wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT) && (type.id() == cudf::type_id::INT32 || type.id() == cudf::type_id::INT64); - case ENC_ENUM_STRING: return field.wire_type == WT_VARINT && type.id() == cudf::type_id::STRING; + case spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING): + return field.wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT) && + type.id() == cudf::type_id::STRING; default: return false; } } @@ -190,12 +220,21 @@ inline void validate_decode_context(ProtobufDecodeContext const& context) std::to_string(i)); } } - if (!(field.wire_type == WT_VARINT || field.wire_type == WT_64BIT || - field.wire_type == WT_LEN || field.wire_type == WT_32BIT)) { + if (!(field.wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT) || + field.wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT) || + field.wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN) || + field.wire_type == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT))) { throw std::invalid_argument("protobuf decode context: invalid wire type at field " + std::to_string(i)); } - if (field.encoding < ENC_DEFAULT || field.encoding > ENC_ENUM_STRING) { + if (field.encoding < + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT) || + field.encoding > + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { throw 
std::invalid_argument("protobuf decode context: invalid encoding at field " + std::to_string(i)); } @@ -215,7 +254,8 @@ inline void validate_decode_context(ProtobufDecodeContext const& context) "protobuf decode context: incompatible wire type/encoding/output type at field " + std::to_string(i)); } - if (field.encoding == ENC_ENUM_STRING) { + if (field.encoding == + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { if (context.enum_valid_values[i].empty() || context.enum_names[i].empty()) { throw std::invalid_argument( "protobuf decode context: enum-as-string field requires non-empty metadata at field " + diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf_builders.cu index 113502a7d2..9720d783be 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf_builders.cu @@ -940,7 +940,8 @@ std::unique_ptr build_repeated_struct_column( break; } case cudf::type_id::STRING: { - if (enc == spark_rapids_jni::ENC_ENUM_STRING) { + if (enc == + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { if (child_schema_idx < static_cast(enum_valid_values.size()) && child_schema_idx < static_cast(enum_names.size()) && !enum_valid_values[child_schema_idx].empty() && @@ -1250,7 +1251,8 @@ std::unique_ptr build_nested_struct_column( break; } case cudf::type_id::STRING: { - if (enc == spark_rapids_jni::ENC_ENUM_STRING) { + if (enc == + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; @@ -1594,7 +1596,8 @@ std::unique_ptr build_repeated_child_list_column( propagate_invalid_rows); } else if (elem_type_id == cudf::type_id::STRING || elem_type_id == cudf::type_id::LIST) { if (elem_type_id == cudf::type_id::STRING && - schema[child_schema_idx].encoding == spark_rapids_jni::ENC_ENUM_STRING) { + schema[child_schema_idx].encoding == + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { if (child_schema_idx < static_cast(enum_valid_values.size()) && child_schema_idx < static_cast(enum_names.size()) && !enum_valid_values[child_schema_idx].empty() && diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf_common.cuh index 2866cc1c08..6178e5e970 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf_common.cuh @@ -53,14 +53,6 @@ namespace spark_rapids_jni::protobuf_detail { -// Wire type constants (protobuf encoding spec) -constexpr int WT_VARINT = 0; -constexpr int WT_64BIT = 1; -constexpr int WT_LEN = 2; -constexpr int WT_SGROUP = 3; -constexpr int WT_EGROUP = 4; -constexpr int WT_32BIT = 5; - // Protobuf varint encoding uses at most 10 bytes to represent a 64-bit value. 
constexpr int MAX_VARINT_BYTES = 10; @@ -208,7 +200,7 @@ inline void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stre __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) { switch (wt) { - case WT_VARINT: { + case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT): { // Need to scan to find the end of varint int count = 0; while (cur < end && count < MAX_VARINT_BYTES) { @@ -217,15 +209,15 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con } return -1; // Invalid varint } - case WT_64BIT: + case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT): // Check if there's enough data for 8 bytes if (end - cur < 8) return -1; return 8; - case WT_32BIT: + case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT): // Check if there's enough data for 4 bytes if (end - cur < 4) return -1; return 4; - case WT_LEN: { + case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN): { uint64_t len; int n; if (!read_varint(cur, end, len, n)) return -1; @@ -233,7 +225,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con return -1; return n + static_cast(len); } - case WT_SGROUP: { + case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::SGROUP): { auto const* start = cur; int depth = 1; while (cur < end && depth > 0) { @@ -243,23 +235,27 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con cur += key_bytes; int inner_wt = static_cast(key & 0x7); - if (inner_wt == WT_EGROUP) { + if (inner_wt == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::EGROUP)) { --depth; if (depth == 0) { return static_cast(cur - start); } - } else if (inner_wt == WT_SGROUP) { + } else if (inner_wt == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::SGROUP)) { if (++depth > 32) return -1; } else { int 
inner_size = -1; switch (inner_wt) { - case WT_VARINT: { + case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT): { uint64_t dummy; int vbytes; if (!read_varint(cur, end, dummy, vbytes)) return -1; inner_size = vbytes; break; } - case WT_64BIT: inner_size = 8; break; - case WT_LEN: { + case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT): + inner_size = 8; + break; + case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN): { uint64_t len; int len_bytes; if (!read_varint(cur, end, len, len_bytes)) return -1; @@ -267,7 +263,9 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con inner_size = len_bytes + static_cast(len); break; } - case WT_32BIT: inner_size = 4; break; + case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT): + inner_size = 4; + break; default: return -1; } if (inner_size < 0 || cur + inner_size > end) return -1; @@ -276,7 +274,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con } return -1; } - case WT_EGROUP: return 0; + case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::EGROUP): return 0; default: return -1; } } @@ -287,9 +285,12 @@ __device__ inline bool skip_field(uint8_t const* cur, uint8_t const*& out_cur) { // A bare end-group is only valid while a start-group payload is being parsed recursively inside - // get_wire_type_size(WT_SGROUP). The scan/count kernels should never accept it as a standalone - // field because Spark CPU treats unmatched end-groups as malformed protobuf. - if (wt == WT_EGROUP) { return false; } + // get_wire_type_size(spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::SGROUP)). + // The scan/count kernels should never accept it as a standalone field because Spark CPU treats + // unmatched end-groups as malformed protobuf. 
+ if (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::EGROUP)) { + return false; + } int size = get_wire_type_size(wt, cur, end); if (size < 0) return false; @@ -306,7 +307,7 @@ __device__ inline bool skip_field(uint8_t const* cur, __device__ inline bool get_field_data_location( uint8_t const* cur, uint8_t const* end, int wt, int32_t& data_offset, int32_t& data_length) { - if (wt == WT_LEN) { + if (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)) { // For length-delimited, read the length prefix uint64_t len; int len_bytes; @@ -654,7 +655,8 @@ __global__ void extract_fixed_kernel(uint8_t const* message_data, uint8_t const* cur = message_data + data_offset; OutputType value; - if constexpr (WT == WT_32BIT) { + if constexpr (WT == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)) { if (loc.length < 4) { set_error_once(error_flag, ERR_FIXED_LEN); if (valid) valid[idx] = false; @@ -779,7 +781,8 @@ __global__ void extract_fixed_batched_kernel(uint8_t const* message_data, uint8_t const* cur = message_data + data_offset; OutputType value; - if constexpr (WT == WT_32BIT) { + if constexpr (WT == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)) { if (loc.length < 4) { set_error_once(error_flag, ERR_FIXED_LEN); desc.valid[row] = false; @@ -897,7 +900,8 @@ inline void extract_integer_into_buffers(uint8_t const* message_data, int* error_ptr, rmm::cuda_stream_view stream) { - if (enable_zigzag && encoding == spark_rapids_jni::ENC_ZIGZAG) { + if (enable_zigzag && + encoding == spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ZIGZAG)) { extract_varint_kernel <<>>(message_data, loc_provider, @@ -907,9 +911,13 @@ inline void extract_integer_into_buffers(uint8_t const* message_data, error_ptr, has_default, default_value); - } else if (encoding == spark_rapids_jni::ENC_FIXED) { + } else if (encoding == + 
spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED)) { if constexpr (sizeof(T) == 4) { - extract_fixed_kernel + extract_fixed_kernel <<>>(message_data, loc_provider, num_rows, @@ -920,7 +928,10 @@ inline void extract_integer_into_buffers(uint8_t const* message_data, static_cast(default_value)); } else { static_assert(sizeof(T) == 8, "extract_integer_into_buffers only supports 32/64-bit"); - extract_fixed_kernel + extract_fixed_kernel <<>>(message_data, loc_provider, num_rows, @@ -1669,7 +1680,10 @@ inline std::unique_ptr extract_typed_column( dt, num_items, [&](float* out_ptr, bool* valid_ptr) { - extract_fixed_kernel + extract_fixed_kernel <<>>(message_data, loc_provider, num_items, @@ -1688,7 +1702,10 @@ inline std::unique_ptr extract_typed_column( dt, num_items, [&](double* out_ptr, bool* valid_ptr) { - extract_fixed_kernel + extract_fixed_kernel <<>>(message_data, loc_provider, num_items, @@ -1764,23 +1781,31 @@ inline std::unique_ptr build_repeated_scalar_column( auto const blocks = static_cast((total_count + threads - 1u) / threads); int encoding = field_desc.encoding; - bool zigzag = (encoding == spark_rapids_jni::ENC_ZIGZAG); + bool zigzag = + (encoding == spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ZIGZAG)); // For float/double types, always use fixed kernel (they use wire type 32BIT/64BIT) - // For integer types, use fixed kernel only if encoding is ENC_FIXED + // For integer types, use fixed kernel only if encoding is + // spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED) constexpr bool is_floating_point = std::is_same_v || std::is_same_v; - bool use_fixed_kernel = is_floating_point || (encoding == spark_rapids_jni::ENC_FIXED); + bool use_fixed_kernel = + is_floating_point || + (encoding == spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED)); RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; if (use_fixed_kernel) { if 
constexpr (sizeof(T) == 4) { - extract_fixed_kernel - <<>>( - message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + extract_fixed_kernel<<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } else { - extract_fixed_kernel - <<>>( - message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + extract_fixed_kernel<<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } } else if (zigzag) { extract_varint_kernel diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf_kernels.cu index fa3cc54999..118c918ebd 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf_kernels.cu @@ -99,7 +99,7 @@ __global__ void scan_all_fields_kernel( // Record the location (relative to message start) int data_offset = static_cast(cur - bytes - start); - if (wt == WT_LEN) { + if (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)) { // For length-delimited, record offset after length prefix and the data length uint64_t len; int len_bytes; @@ -165,7 +165,9 @@ __device__ bool count_repeated_element(uint8_t const* cur, repeated_field_info& info, int* error_flag) { - bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); + bool is_packed = + (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN) && + expected_wt != spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)); if (!is_packed && wt != expected_wt) { set_error_once(error_flag, ERR_WIRE_TYPE); @@ -187,7 +189,8 @@ __device__ bool count_repeated_element(uint8_t const* cur, uint8_t const* packed_end = packed_start + packed_len; int count = 0; - if (expected_wt == WT_VARINT) { + if (expected_wt == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT)) { uint8_t const* p = packed_start; while (p < packed_end) { uint64_t dummy; @@ -199,13 +202,15 @@ __device__ bool 
count_repeated_element(uint8_t const* cur, p += vbytes; count++; } - } else if (expected_wt == WT_32BIT) { + } else if (expected_wt == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)) { if ((packed_len % 4) != 0) { set_error_once(error_flag, ERR_FIXED_LEN); return false; } count = static_cast(packed_len / 4); - } else if (expected_wt == WT_64BIT) { + } else if (expected_wt == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT)) { if ((packed_len % 8) != 0) { set_error_once(error_flag, ERR_FIXED_LEN); return false; @@ -244,7 +249,9 @@ __device__ bool scan_repeated_element(uint8_t const* cur, int write_end, int* error_flag) { - bool is_packed = (wt == WT_LEN && expected_wt != WT_LEN); + bool is_packed = + (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN) && + expected_wt != spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)); if (!is_packed && wt != expected_wt) { set_error_once(error_flag, ERR_WIRE_TYPE); @@ -265,7 +272,8 @@ __device__ bool scan_repeated_element(uint8_t const* cur, } uint8_t const* packed_end = packed_start + packed_len; - if (expected_wt == WT_VARINT) { + if (expected_wt == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT)) { uint8_t const* p = packed_start; while (p < packed_end) { int32_t elem_offset = static_cast(p - msg_base); @@ -283,7 +291,8 @@ __device__ bool scan_repeated_element(uint8_t const* cur, write_idx++; p += vbytes; } - } else if (expected_wt == WT_32BIT) { + } else if (expected_wt == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)) { if ((packed_len % 4) != 0) { set_error_once(error_flag, ERR_FIXED_LEN); return false; @@ -296,7 +305,8 @@ __device__ bool scan_repeated_element(uint8_t const* cur, occurrences[write_idx] = {row, static_cast(packed_start - msg_base + i), 4}; write_idx++; } - } else if (expected_wt == WT_64BIT) { + } else if (expected_wt == + 
spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT)) { if ((packed_len % 8) != 0) { set_error_once(error_flag, ERR_FIXED_LEN); return false; @@ -428,7 +438,7 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in // Check nested message fields at this depth auto handle_nested = [&](int i) { - if (wt != WT_LEN) { + if (wt != spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)) { set_error_once(error_flag, ERR_WIRE_TYPE); return false; } @@ -531,8 +541,10 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co int wt = tag.wire_type; auto try_scan = [&](int f) -> bool { - int target_wt = scan_descs[f].wire_type; - bool is_packed = (wt == WT_LEN && target_wt != WT_LEN); + int target_wt = scan_descs[f].wire_type; + bool is_packed = + (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN) && + target_wt != spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)); if (is_packed || wt == target_wt) { return scan_repeated_element(cur, msg_end, @@ -645,7 +657,7 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, int data_offset = static_cast(cur - nested_start); - if (wt == WT_LEN) { + if (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)) { uint64_t len; int len_bytes; if (!read_varint(cur, nested_end, len, len_bytes)) { @@ -748,7 +760,7 @@ __global__ void scan_repeated_message_children_kernel( } else { int data_offset = static_cast(cur - msg_start); - if (wt == WT_LEN) { + if (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)) { uint64_t len; int len_bytes; if (!read_varint(cur, msg_end, len, len_bytes)) { @@ -772,7 +784,7 @@ __global__ void scan_repeated_message_children_kernel( } else { // For varint/fixed types, store offset and estimated length int32_t data_length = 0; - if (wt == WT_VARINT) { + if (wt == 
spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT)) { uint64_t dummy; int vbytes; if (!read_varint(cur, msg_end, dummy, vbytes)) { @@ -780,13 +792,15 @@ __global__ void scan_repeated_message_children_kernel( return; } data_length = vbytes; - } else if (wt == WT_32BIT) { + } else if (wt == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)) { if (msg_end - cur < 4) { set_error_once(error_flag, ERR_FIXED_LEN); return; } data_length = 4; - } else if (wt == WT_64BIT) { + } else if (wt == + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT)) { if (msg_end - cur < 8) { set_error_once(error_flag, ERR_FIXED_LEN); return; From 0213433dd980ed1127020bb36fc6375c19917611 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 19 Mar 2026 10:12:36 +0800 Subject: [PATCH 093/107] port back refactor from pr 0 Signed-off-by: Haoyang Li --- src/main/cpp/CMakeLists.txt | 6 +- src/main/cpp/src/ProtobufJni.cpp | 83 +- src/main/cpp/src/protobuf.hpp | 319 ------ src/main/cpp/src/{ => protobuf}/protobuf.cu | 271 ++++- src/main/cpp/src/protobuf/protobuf.hpp | 109 ++ .../src/{ => protobuf}/protobuf_builders.cu | 24 +- .../src/protobuf/protobuf_device_helpers.cuh | 300 ++++++ .../protobuf_host_helpers.hpp} | 971 +----------------- .../src/{ => protobuf}/protobuf_kernels.cu | 76 +- .../cpp/src/protobuf/protobuf_kernels.cuh | 571 ++++++++++ src/main/cpp/src/protobuf/protobuf_types.cuh | 162 +++ 11 files changed, 1453 insertions(+), 1439 deletions(-) delete mode 100644 src/main/cpp/src/protobuf.hpp rename src/main/cpp/src/{ => protobuf}/protobuf.cu (83%) create mode 100644 src/main/cpp/src/protobuf/protobuf.hpp rename src/main/cpp/src/{ => protobuf}/protobuf_builders.cu (99%) create mode 100644 src/main/cpp/src/protobuf/protobuf_device_helpers.cuh rename src/main/cpp/src/{protobuf_common.cuh => protobuf/protobuf_host_helpers.hpp} (52%) rename src/main/cpp/src/{ => protobuf}/protobuf_kernels.cu (94%) create mode 100644 
src/main/cpp/src/protobuf/protobuf_kernels.cuh create mode 100644 src/main/cpp/src/protobuf/protobuf_types.cuh diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index f7a2d0f0c4..1ed650d111 100644 --- a/src/main/cpp/CMakeLists.txt +++ b/src/main/cpp/CMakeLists.txt @@ -255,9 +255,9 @@ add_library( src/multiply.cu src/number_converter.cu src/parse_uri.cu - src/protobuf.cu - src/protobuf_kernels.cu - src/protobuf_builders.cu + src/protobuf/protobuf.cu + src/protobuf/protobuf_kernels.cu + src/protobuf/protobuf_builders.cu src/regex_rewrite_utils.cu src/row_conversion.cu src/round_float.cu diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index d1fc753a62..4bfc9a9937 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -16,7 +16,7 @@ #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" -#include "protobuf.hpp" +#include "protobuf/protobuf.hpp" #include #include @@ -98,67 +98,16 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, 0); } - // Validate schema topology and wire types: - // - parent index must be -1 or a prior field index - // - depth must be 0 for top-level and parent_depth + 1 for children - // - wire type must be one of {0, 1, 2, 5} - for (int i = 0; i < num_fields; ++i) { - auto const parent_idx = n_parent_indices[i]; - auto const depth = n_depth_levels[i]; - auto const wire_type = n_wire_types[i]; - - if (n_field_numbers[i] <= 0) { - JNI_THROW_NEW( - env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, "field_numbers must be positive", 0); - } - - if (!(wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT) || - wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT) || - wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN) || - wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT))) { - JNI_THROW_NEW(env, - 
cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, - "wire_types must be one of {VARINT,I64BIT,LEN,I32BIT}", - 0); - } - - if (parent_idx < -1 || parent_idx >= num_fields || parent_idx >= i) { - JNI_THROW_NEW(env, - cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, - "parent_indices must be -1 or a valid prior field index", - 0); - } - - if (parent_idx == -1) { - if (depth != 0) { - JNI_THROW_NEW( - env, cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, "top-level fields must have depth 0", 0); - } - } else { - auto const parent_depth = n_depth_levels[parent_idx]; - if (depth != parent_depth + 1) { - JNI_THROW_NEW(env, - cudf::jni::ILLEGAL_ARG_EXCEPTION_CLASS, - "child depth must equal parent depth + 1", - 0); - } - } - } - // Build schema descriptors - std::vector schema; + std::vector schema; schema.reserve(num_fields); for (int i = 0; i < num_fields; ++i) { schema.push_back({n_field_numbers[i], n_parent_indices[i], n_depth_levels[i], - n_wire_types[i], + static_cast(n_wire_types[i]), static_cast(n_output_type_ids[i]), - n_encodings[i], + static_cast(n_encodings[i]), n_is_repeated[i] != 0, n_is_required[i] != 0, n_has_default[i] != 0}); @@ -266,18 +215,18 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, } } - spark_rapids_jni::ProtobufDecodeContext context{std::move(schema), - std::move(schema_output_types), - std::move(default_int_values), - std::move(default_float_values), - std::move(default_bool_values), - std::move(default_string_values), - std::move(enum_values), - std::move(enum_name_values), - static_cast(fail_on_errors)}; - - auto result = - spark_rapids_jni::decode_protobuf_to_struct(*input, context, cudf::get_default_stream()); + spark_rapids_jni::protobuf::ProtobufDecodeContext context{std::move(schema), + std::move(schema_output_types), + std::move(default_int_values), + std::move(default_float_values), + std::move(default_bool_values), + std::move(default_string_values), + std::move(enum_values), + std::move(enum_name_values), + 
static_cast(fail_on_errors)}; + + auto result = spark_rapids_jni::protobuf::decode_protobuf_to_struct( + *input, context, cudf::get_default_stream(), cudf::get_current_device_resource_ref()); return cudf::jni::release_as_jlong(result); } diff --git a/src/main/cpp/src/protobuf.hpp b/src/main/cpp/src/protobuf.hpp deleted file mode 100644 index 5a4844a58f..0000000000 --- a/src/main/cpp/src/protobuf.hpp +++ /dev/null @@ -1,319 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -#include - -#include -#include -#include -#include -#include - -namespace spark_rapids_jni { - -// Encoding constants -enum class proto_encoding : int { - DEFAULT = 0, - FIXED = 1, - ZIGZAG = 2, - ENUM_STRING = 3, -}; -CUDF_HOST_DEVICE constexpr int encoding_value(proto_encoding encoding) -{ - return static_cast(encoding); -} -constexpr int MAX_FIELD_NUMBER = (1 << 29) - 1; - -// Wire type constants -enum class proto_wire_type : int { - VARINT = 0, - I64BIT = 1, - LEN = 2, - SGROUP = 3, - EGROUP = 4, - I32BIT = 5, -}; -CUDF_HOST_DEVICE constexpr int wire_type_value(proto_wire_type wire_type) -{ - return static_cast(wire_type); -} - -// Maximum nesting depth for nested messages -constexpr int MAX_NESTING_DEPTH = 10; - -/** - * Descriptor for a field in a nested protobuf schema. - * Used to represent flattened schema with parent-child relationships. 
- */ -struct nested_field_descriptor { - int field_number; // Protobuf field number - int parent_idx; // Index of parent field in schema (-1 for top-level) - int depth; // Nesting depth (0 for top-level) - int wire_type; // Expected wire type (proto_wire_type) - cudf::type_id output_type; // Output cudf type - int encoding; // Encoding type (proto_encoding) - bool is_repeated; // Whether this field is repeated (array) - bool is_required; // Whether this field is required (proto2) - bool has_default_value; // Whether this field has a default value -}; - -/** - * Context and schema information for decoding protobuf messages. - */ -struct ProtobufDecodeContext { - std::vector schema; - std::vector schema_output_types; - std::vector default_ints; - std::vector default_floats; - std::vector default_bools; - std::vector> default_strings; - std::vector> enum_valid_values; - std::vector>> enum_names; - bool fail_on_errors; -}; - -struct ProtobufFieldMetaView { - nested_field_descriptor const& schema; - cudf::data_type const& output_type; - int64_t default_int; - double default_float; - bool default_bool; - std::vector const& default_string; - std::vector const& enum_valid_values; - std::vector> const& enum_names; -}; - -inline bool is_encoding_compatible(nested_field_descriptor const& field, - cudf::data_type const& type) -{ - switch (field.encoding) { - case spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT): - switch (type.id()) { - case cudf::type_id::BOOL8: - case cudf::type_id::INT32: - case cudf::type_id::UINT32: - case cudf::type_id::INT64: - case cudf::type_id::UINT64: - return field.wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT); - case cudf::type_id::FLOAT32: - return field.wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT); - case cudf::type_id::FLOAT64: - return field.wire_type == - 
spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT); - case cudf::type_id::STRING: - case cudf::type_id::LIST: - case cudf::type_id::STRUCT: - return field.wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN); - default: return false; - } - case spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED): - switch (type.id()) { - case cudf::type_id::INT32: - case cudf::type_id::UINT32: - case cudf::type_id::FLOAT32: - return field.wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT); - case cudf::type_id::INT64: - case cudf::type_id::UINT64: - case cudf::type_id::FLOAT64: - return field.wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT); - default: return false; - } - case spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ZIGZAG): - return field.wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT) && - (type.id() == cudf::type_id::INT32 || type.id() == cudf::type_id::INT64); - case spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING): - return field.wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT) && - type.id() == cudf::type_id::STRING; - default: return false; - } -} - -inline void validate_decode_context(ProtobufDecodeContext const& context) -{ - auto const num_fields = context.schema.size(); - auto const fail_size = [&](char const* name, size_t actual) { - throw std::invalid_argument(std::string("protobuf decode context: ") + name + - " size mismatch with schema (" + std::to_string(actual) + " vs " + - std::to_string(num_fields) + ")"); - }; - - if (context.schema_output_types.size() != num_fields) - fail_size("schema_output_types", context.schema_output_types.size()); - if (context.default_ints.size() != num_fields) - fail_size("default_ints", context.default_ints.size()); - if 
(context.default_floats.size() != num_fields) - fail_size("default_floats", context.default_floats.size()); - if (context.default_bools.size() != num_fields) - fail_size("default_bools", context.default_bools.size()); - if (context.default_strings.size() != num_fields) - fail_size("default_strings", context.default_strings.size()); - if (context.enum_valid_values.size() != num_fields) - fail_size("enum_valid_values", context.enum_valid_values.size()); - if (context.enum_names.size() != num_fields) fail_size("enum_names", context.enum_names.size()); - - std::set> seen_field_numbers; - for (size_t i = 0; i < num_fields; ++i) { - auto const& field = context.schema[i]; - auto const& type = context.schema_output_types[i]; - if (type.id() != field.output_type) { - throw std::invalid_argument( - "protobuf decode context: schema_output_types id mismatch at field " + std::to_string(i)); - } - if (field.field_number <= 0 || field.field_number > MAX_FIELD_NUMBER) { - throw std::invalid_argument("protobuf decode context: invalid field number at field " + - std::to_string(i)); - } - if (field.depth < 0 || field.depth >= MAX_NESTING_DEPTH) { - throw std::invalid_argument( - "protobuf decode context: field depth exceeds supported limit at field " + - std::to_string(i)); - } - if (field.parent_idx < -1 || field.parent_idx >= static_cast(i)) { - throw std::invalid_argument("protobuf decode context: invalid parent index at field " + - std::to_string(i)); - } - if (!seen_field_numbers.emplace(field.parent_idx, field.field_number).second) { - throw std::invalid_argument( - "protobuf decode context: duplicate field number under same parent at field " + - std::to_string(i)); - } - if (field.parent_idx == -1) { - if (field.depth != 0) { - throw std::invalid_argument( - "protobuf decode context: top-level field must have depth 0 at field " + - std::to_string(i)); - } - } else { - auto const& parent = context.schema[field.parent_idx]; - if (field.depth != parent.depth + 1) { - throw 
std::invalid_argument("protobuf decode context: child depth mismatch at field " + - std::to_string(i)); - } - if (context.schema_output_types[field.parent_idx].id() != cudf::type_id::STRUCT) { - throw std::invalid_argument("protobuf decode context: parent must be STRUCT at field " + - std::to_string(i)); - } - } - if (!(field.wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT) || - field.wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT) || - field.wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN) || - field.wire_type == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT))) { - throw std::invalid_argument("protobuf decode context: invalid wire type at field " + - std::to_string(i)); - } - if (field.encoding < - spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT) || - field.encoding > - spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { - throw std::invalid_argument("protobuf decode context: invalid encoding at field " + - std::to_string(i)); - } - if (field.is_repeated && field.has_default_value) { - throw std::invalid_argument( - "protobuf decode context: repeated field cannot carry default value at field " + - std::to_string(i)); - } - if (field.has_default_value && - (type.id() == cudf::type_id::STRUCT || type.id() == cudf::type_id::LIST)) { - throw std::invalid_argument( - "protobuf decode context: STRUCT/LIST field cannot carry default value at field " + - std::to_string(i)); - } - if (!is_encoding_compatible(field, type)) { - throw std::invalid_argument( - "protobuf decode context: incompatible wire type/encoding/output type at field " + - std::to_string(i)); - } - if (field.encoding == - spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { - if (context.enum_valid_values[i].empty() || context.enum_names[i].empty()) { - throw 
std::invalid_argument( - "protobuf decode context: enum-as-string field requires non-empty metadata at field " + - std::to_string(i)); - } - if (context.enum_valid_values[i].size() != context.enum_names[i].size()) { - throw std::invalid_argument( - "protobuf decode context: enum-as-string metadata mismatch at field " + - std::to_string(i)); - } - } - } -} - -inline ProtobufFieldMetaView make_field_meta_view(ProtobufDecodeContext const& context, - int schema_idx) -{ - auto const idx = static_cast(schema_idx); - return ProtobufFieldMetaView{context.schema.at(idx), - context.schema_output_types.at(idx), - context.default_ints.at(idx), - context.default_floats.at(idx), - context.default_bools.at(idx), - context.default_strings.at(idx), - context.enum_valid_values.at(idx), - context.enum_names.at(idx)}; -} - -/** - * Decode protobuf messages (one message per row) from a LIST column into a STRUCT - * column, with support for nested messages and repeated fields. - * - * This uses a multi-pass approach: - * - Pass 1: Scan all messages, count nested elements and repeated field occurrences - * - Pass 2: Prefix sum to compute output offsets for arrays and nested structs - * - Pass 3: Extract data using pre-computed offsets - * - Pass 4: Build nested column structure - * - * The schema is represented as a flattened array of field descriptors with parent-child - * relationships. Top-level fields have parent_idx == -1 and depth == 0. For pure scalar - * schemas, all fields are top-level with is_repeated == false. 
- * - * Supported output child types (cudf dtypes) and corresponding protobuf field types: - * - BOOL8 : protobuf `bool` (varint wire type) - * - INT32 : protobuf `int32`, `sint32` (with zigzag), `fixed32`/`sfixed32` (with fixed encoding) - * - INT64 : protobuf `int64`, `sint64` (with zigzag), `fixed64`/`sfixed64` (with fixed encoding) - * - FLOAT32 : protobuf `float` (fixed32 wire type) - * - FLOAT64 : protobuf `double` (fixed64 wire type) - * - STRING : protobuf `string` (length-delimited wire type, UTF-8 text) - * - LIST : protobuf `bytes` (length-delimited wire type, raw bytes as LIST) - * - STRUCT : protobuf nested `message` - * - * @param binary_input LIST column, each row is one protobuf message - * @param context Decoding context containing schema and default values - * @return STRUCT column with nested structure - */ -std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, - ProtobufDecodeContext const& context, - rmm::cuda_stream_view stream); - -} // namespace spark_rapids_jni diff --git a/src/main/cpp/src/protobuf.cu b/src/main/cpp/src/protobuf/protobuf.cu similarity index 83% rename from src/main/cpp/src/protobuf.cu rename to src/main/cpp/src/protobuf/protobuf.cu index 5a7ce0540c..e499e165fd 100644 --- a/src/main/cpp/src/protobuf.cu +++ b/src/main/cpp/src/protobuf/protobuf.cu @@ -14,13 +14,15 @@ * limitations under the License. 
*/ -#include "protobuf_common.cuh" +#include "protobuf/protobuf_host_helpers.hpp" #include +#include +#include -using namespace spark_rapids_jni::protobuf_detail; +using namespace spark_rapids_jni::protobuf::detail; -namespace spark_rapids_jni { +namespace spark_rapids_jni::protobuf { namespace { @@ -132,9 +134,170 @@ void propagate_nulls_to_descendants(cudf::column& col, } // namespace +bool is_encoding_compatible(nested_field_descriptor const& field, cudf::data_type const& type) +{ + switch (field.encoding) { + case proto_encoding::DEFAULT: + switch (type.id()) { + case cudf::type_id::BOOL8: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: return field.wire_type == proto_wire_type::VARINT; + case cudf::type_id::FLOAT32: return field.wire_type == proto_wire_type::I32BIT; + case cudf::type_id::FLOAT64: return field.wire_type == proto_wire_type::I64BIT; + case cudf::type_id::STRING: + case cudf::type_id::LIST: + case cudf::type_id::STRUCT: return field.wire_type == proto_wire_type::LEN; + default: return false; + } + case proto_encoding::FIXED: + switch (type.id()) { + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::FLOAT32: return field.wire_type == proto_wire_type::I32BIT; + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT64: return field.wire_type == proto_wire_type::I64BIT; + default: return false; + } + case proto_encoding::ZIGZAG: + return field.wire_type == proto_wire_type::VARINT && + (type.id() == cudf::type_id::INT32 || type.id() == cudf::type_id::INT64); + case proto_encoding::ENUM_STRING: + return field.wire_type == proto_wire_type::VARINT && type.id() == cudf::type_id::STRING; + default: return false; + } +} + +void validate_decode_context(ProtobufDecodeContext const& context) +{ + auto const num_fields = context.schema.size(); + CUDF_EXPECTS(context.schema_output_types.size() == num_fields, + "protobuf decode context: 
schema_output_types size mismatch", + std::invalid_argument); + CUDF_EXPECTS(context.default_ints.size() == num_fields, + "protobuf decode context: default_ints size mismatch", + std::invalid_argument); + CUDF_EXPECTS(context.default_floats.size() == num_fields, + "protobuf decode context: default_floats size mismatch", + std::invalid_argument); + CUDF_EXPECTS(context.default_bools.size() == num_fields, + "protobuf decode context: default_bools size mismatch", + std::invalid_argument); + CUDF_EXPECTS(context.default_strings.size() == num_fields, + "protobuf decode context: default_strings size mismatch", + std::invalid_argument); + CUDF_EXPECTS(context.enum_valid_values.size() == num_fields, + "protobuf decode context: enum_valid_values size mismatch", + std::invalid_argument); + CUDF_EXPECTS(context.enum_names.size() == num_fields, + "protobuf decode context: enum_names size mismatch", + std::invalid_argument); + + std::set> seen_field_numbers; + for (size_t i = 0; i < num_fields; ++i) { + auto const& field = context.schema[i]; + auto const& type = context.schema_output_types[i]; + CUDF_EXPECTS( + type.id() == field.output_type, + "protobuf decode context: schema_output_types id mismatch at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(field.field_number > 0 && field.field_number <= MAX_FIELD_NUMBER, + "protobuf decode context: invalid field number at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(field.depth >= 0 && field.depth < MAX_NESTING_DEPTH, + "protobuf decode context: field depth exceeds limit at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(field.parent_idx >= -1 && field.parent_idx < static_cast(i), + "protobuf decode context: invalid parent index at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(seen_field_numbers.emplace(field.parent_idx, field.field_number).second, + "protobuf decode context: duplicate field number under same parent at field " + + 
std::to_string(i), + std::invalid_argument); + + if (field.parent_idx == -1) { + CUDF_EXPECTS( + field.depth == 0, + "protobuf decode context: top-level field must have depth 0 at field " + std::to_string(i), + std::invalid_argument); + } else { + auto const& parent = context.schema[field.parent_idx]; + CUDF_EXPECTS(field.depth == parent.depth + 1, + "protobuf decode context: child depth mismatch at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(context.schema_output_types[field.parent_idx].id() == cudf::type_id::STRUCT, + "protobuf decode context: parent must be STRUCT at field " + std::to_string(i), + std::invalid_argument); + } + + CUDF_EXPECTS( + field.wire_type == proto_wire_type::VARINT || field.wire_type == proto_wire_type::I64BIT || + field.wire_type == proto_wire_type::LEN || field.wire_type == proto_wire_type::I32BIT, + "protobuf decode context: invalid wire type at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS( + field.encoding >= proto_encoding::DEFAULT && field.encoding <= proto_encoding::ENUM_STRING, + "protobuf decode context: invalid encoding at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(!(field.is_repeated && field.is_required), + "protobuf decode context: field cannot be both repeated and required at field " + + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(!(field.is_repeated && field.has_default_value), + "protobuf decode context: repeated field cannot carry default value at field " + + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(!(field.has_default_value && + (type.id() == cudf::type_id::STRUCT || type.id() == cudf::type_id::LIST)), + "protobuf decode context: STRUCT/LIST field cannot carry default value at field " + + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(is_encoding_compatible(field, type), + "protobuf decode context: incompatible wire type/encoding/output type at field " + + std::to_string(i), + 
std::invalid_argument); + + if (field.encoding == proto_encoding::ENUM_STRING) { + CUDF_EXPECTS( + !(context.enum_valid_values[i].empty() || context.enum_names[i].empty()), + "protobuf decode context: enum-as-string field requires non-empty metadata at field " + + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS( + context.enum_valid_values[i].size() == context.enum_names[i].size(), + "protobuf decode context: enum-as-string metadata mismatch at field " + std::to_string(i), + std::invalid_argument); + auto const& ev = context.enum_valid_values[i]; + for (size_t j = 1; j < ev.size(); ++j) { + CUDF_EXPECTS( + ev[j] > ev[j - 1], + "protobuf decode context: enum_valid_values must be strictly sorted at field " + + std::to_string(i), + std::invalid_argument); + } + } + } +} + +ProtobufFieldMetaView make_field_meta_view(ProtobufDecodeContext const& context, int schema_idx) +{ + auto const idx = static_cast(schema_idx); + return ProtobufFieldMetaView{context.schema.at(idx), + context.schema_output_types.at(idx), + context.default_ints.at(idx), + context.default_floats.at(idx), + context.default_bools.at(idx), + context.default_strings.at(idx), + context.enum_valid_values.at(idx), + context.enum_names.at(idx)}; +} + std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, ProtobufDecodeContext const& context, - rmm::cuda_stream_view stream) + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) { validate_decode_context(context); auto const& schema = context.schema; @@ -153,7 +316,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& CUDF_EXPECTS(child_type == cudf::type_id::INT8 || child_type == cudf::type_id::UINT8, "binary_input must be a LIST column"); - auto mr = cudf::get_current_device_resource_ref(); auto num_rows = binary_input.size(); auto num_fields = static_cast(schema.size()); @@ -301,10 +463,10 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& rmm::device_uvector 
d_fn_to_nested(0, stream, mr); if (num_repeated > 0 || num_nested > 0) { - auto h_fn_to_rep = protobuf_detail::build_index_lookup_table( - schema.data(), repeated_field_indices.data(), num_repeated); - auto h_fn_to_nested = protobuf_detail::build_index_lookup_table( - schema.data(), nested_field_indices.data(), num_nested); + auto h_fn_to_rep = + build_index_lookup_table(schema.data(), repeated_field_indices.data(), num_repeated); + auto h_fn_to_nested = + build_index_lookup_table(schema.data(), nested_field_indices.data(), num_nested); if (!h_fn_to_rep.empty()) { d_fn_to_rep = rmm::device_uvector(h_fn_to_rep.size(), stream, mr); @@ -350,7 +512,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& for (int i = 0; i < num_scalar; i++) { int schema_idx = scalar_field_indices[i]; h_field_descs[i].field_number = schema[schema_idx].field_number; - h_field_descs[i].expected_wire_type = schema[schema_idx].wire_type; + h_field_descs[i].expected_wire_type = static_cast(schema[schema_idx].wire_type); h_field_descs[i].is_repeated = false; } @@ -420,15 +582,13 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& for (int i = 0; i < num_scalar; i++) { int si = scalar_field_indices[i]; auto tid = schema_output_types[si].id(); - int enc = schema[si].encoding; - bool zz = - (enc == spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ZIGZAG)); + auto enc = schema[si].encoding; + bool zz = (enc == proto_encoding::ZIGZAG); // STRING, LIST, and enum-as-string go to per-field path if (tid == cudf::type_id::STRING || tid == cudf::type_id::LIST) continue; - bool is_fixed = - (enc == spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED)); + bool is_fixed = (enc == proto_encoding::FIXED); // INT32 with enum validation goes to fallback if (tid == cudf::type_id::INT32 && !zz && !is_fixed && @@ -474,7 +634,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& std::vector> bufs; bufs.reserve(nf); - 
std::vector h_descs(nf); + std::vector h_descs(nf); for (int j = 0; j < nf; j++) { int li = idxs[j]; @@ -493,19 +653,18 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Build columns for (int j = 0; j < nf; j++) { - int si = scalar_field_indices[idxs[j]]; - auto dt = schema_output_types[si]; - auto& bp = *bufs[j]; - auto [mask, null_count] = - protobuf_detail::make_null_mask_from_valid(bp.valid, stream, mr); - column_map[si] = std::make_unique( + int si = scalar_field_indices[idxs[j]]; + auto dt = schema_output_types[si]; + auto& bp = *bufs[j]; + auto [mask, null_count] = make_null_mask_from_valid(bp.valid, stream, mr); + column_map[si] = std::make_unique( dt, num_rows, bp.out_bytes.release(), std::move(mask), null_count); } }; // Varint launcher for type T with zigzag ZZ auto varint_launch = [&](int nf, - std::vector& h_descs, + std::vector& h_descs, std::vector>& bufs, size_t elem_size, auto kernel_fn) { @@ -513,7 +672,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& bufs[j]->out_bytes = rmm::device_uvector(num_rows * elem_size, stream, mr); h_descs[j].output = bufs[j]->out_bytes.data(); } - rmm::device_uvector d_descs(nf, stream, mr); + rmm::device_uvector d_descs(nf, stream, mr); CUDF_CUDA_TRY(cudaMemcpyAsync(d_descs.data(), h_descs.data(), nf * sizeof(h_descs[0]), @@ -538,14 +697,14 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& #define LAUNCH_VARINT_BATCH(GROUP, TYPE, ZZ) \ do_batch(group_lists[GROUP], [&](int nf, auto& hd, auto& bf) { \ varint_launch(nf, hd, bf, sizeof(TYPE), [](dim3 g, int t, cudaStream_t s, auto... args) { \ - protobuf_detail::extract_varint_batched_kernel<<>>(args...); \ + extract_varint_batched_kernel<<>>(args...); \ }); \ }) #define LAUNCH_FIXED_BATCH(GROUP, TYPE, WT_VAL) \ do_batch(group_lists[GROUP], [&](int nf, auto& hd, auto& bf) { \ varint_launch(nf, hd, bf, sizeof(TYPE), [](dim3 g, int t, cudaStream_t s, auto... 
args) { \ - protobuf_detail::extract_fixed_batched_kernel<<>>(args...); \ + extract_fixed_batched_kernel<<>>(args...); \ }); \ }) @@ -556,14 +715,22 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& LAUNCH_VARINT_BATCH(4, uint8_t, false); LAUNCH_VARINT_BATCH(5, int32_t, true); LAUNCH_VARINT_BATCH(6, int64_t, true); - LAUNCH_FIXED_BATCH( - 7, float, spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)); - LAUNCH_FIXED_BATCH( - 8, double, spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT)); - LAUNCH_FIXED_BATCH( - 9, int32_t, spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)); - LAUNCH_FIXED_BATCH( - 10, int64_t, spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT)); + LAUNCH_FIXED_BATCH(7, + float, + spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I32BIT)); + LAUNCH_FIXED_BATCH(8, + double, + spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I64BIT)); + LAUNCH_FIXED_BATCH(9, + int32_t, + spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I32BIT)); + LAUNCH_FIXED_BATCH(10, + int64_t, + spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I64BIT)); #undef LAUNCH_VARINT_BATCH #undef LAUNCH_FIXED_BATCH @@ -573,7 +740,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& int schema_idx = scalar_field_indices[i]; auto const field_meta = make_field_meta_view(context, schema_idx); auto const dt = field_meta.output_type; - auto const enc = field_meta.schema.encoding; + auto const enc = static_cast(field_meta.schema.encoding); bool has_def = field_meta.schema.has_default_value; TopLevelLocationProvider loc_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; @@ -610,8 +777,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& switch 
(dt.id()) { case cudf::type_id::STRING: { - if (enc == - spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { + if (enc == proto_encoding::ENUM_STRING) { // ENUM-as-string path: // 1. Decode enum numeric value as INT32 varint. // 2. Validate against enum_valid_values. @@ -657,9 +823,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& list_offsets, base_offset, d_locations.data(), i, num_scalar}; auto valid_fn = [locs = d_locations.data(), i, num_scalar, has_def_str] __device__( cudf::size_type row) { - return locs[protobuf_detail::flat_index(static_cast(row), - static_cast(num_scalar), - static_cast(i))] + return locs[flat_index(static_cast(row), + static_cast(num_scalar), + static_cast(i))] .offset >= 0 || has_def_str; }; @@ -687,9 +853,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& list_offsets, base_offset, d_locations.data(), i, num_scalar}; auto valid_fn = [locs = d_locations.data(), i, num_scalar, has_def_bytes] __device__( cudf::size_type row) { - return locs[protobuf_detail::flat_index(static_cast(row), - static_cast(num_scalar), - static_cast(i))] + return locs[flat_index(static_cast(row), + static_cast(num_scalar), + static_cast(i))] .offset >= 0 || has_def_bytes; }; @@ -737,7 +903,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& int32_t total_count{0}; rmm::device_uvector counts; rmm::device_uvector offsets; - std::unique_ptr> occurrences; + std::unique_ptr> occurrences; repeated_field_work(int si, cudf::size_type n, @@ -775,24 +941,22 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& stream.synchronize(); // Phase B: Allocate occurrence buffers and launch ONE combined scan kernel. 
- std::vector h_scan_descs; + std::vector h_scan_descs; h_scan_descs.reserve(num_repeated); for (auto& wp : rep_work) { if (wp->total_count > 0) { wp->occurrences = - std::make_unique>( - wp->total_count, stream, mr); + std::make_unique>(wp->total_count, stream, mr); h_scan_descs.push_back({schema[wp->schema_idx].field_number, - schema[wp->schema_idx].wire_type, + static_cast(schema[wp->schema_idx].wire_type), wp->offsets.data(), wp->occurrences->data()}); } } if (!h_scan_descs.empty()) { - rmm::device_uvector d_scan_descs( - h_scan_descs.size(), stream, mr); + rmm::device_uvector d_scan_descs(h_scan_descs.size(), stream, mr); CUDF_CUDA_TRY(cudaMemcpyAsync(d_scan_descs.data(), h_scan_descs.data(), h_scan_descs.size() * sizeof(h_scan_descs[0]), @@ -806,7 +970,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } rmm::device_uvector d_fn_to_scan(0, stream, mr); int fn_to_scan_size = 0; - if (max_scan_fn <= protobuf_detail::FIELD_LOOKUP_TABLE_MAX) { + if (max_scan_fn <= FIELD_LOOKUP_TABLE_MAX) { std::vector h_fn_to_scan(max_scan_fn + 1, -1); for (int i = 0; i < static_cast(h_scan_descs.size()); i++) { h_fn_to_scan[h_scan_descs[i].field_number] = i; @@ -954,8 +1118,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& case cudf::type_id::STRING: { auto const field_meta = make_field_meta_view(context, schema_idx); auto enc = field_meta.schema.encoding; - if (enc == - spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { + if (enc == proto_encoding::ENUM_STRING) { if (!field_meta.enum_valid_values.empty() && field_meta.enum_valid_values.size() == field_meta.enum_names.size()) { column_map[schema_idx] = @@ -1204,4 +1367,4 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& num_rows, std::move(top_level_children), struct_null_count, std::move(struct_mask), stream, mr); } -} // namespace spark_rapids_jni +} // namespace spark_rapids_jni::protobuf diff --git 
a/src/main/cpp/src/protobuf/protobuf.hpp b/src/main/cpp/src/protobuf/protobuf.hpp new file mode 100644 index 0000000000..452220d399 --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf.hpp @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace spark_rapids_jni::protobuf { + +enum class proto_encoding : int { + DEFAULT = 0, + FIXED = 1, + ZIGZAG = 2, + ENUM_STRING = 3, +}; + +CUDF_HOST_DEVICE constexpr int encoding_value(proto_encoding encoding) +{ + return static_cast(encoding); +} + +constexpr int MAX_FIELD_NUMBER = (1 << 29) - 1; + +enum class proto_wire_type : int { + VARINT = 0, + I64BIT = 1, + LEN = 2, + SGROUP = 3, + EGROUP = 4, + I32BIT = 5, +}; + +CUDF_HOST_DEVICE constexpr int wire_type_value(proto_wire_type wire_type) +{ + return static_cast(wire_type); +} + +constexpr int MAX_NESTING_DEPTH = 10; + +struct nested_field_descriptor { + int field_number; // Protobuf field number + int parent_idx; // Index of parent field in schema (-1 for top-level) + int depth; // Nesting depth (0 for top-level) + proto_wire_type wire_type; // Expected wire type + cudf::type_id output_type; // Output cudf type + proto_encoding encoding; // Encoding type + bool is_repeated; // Whether this field is repeated (array) + bool is_required; // Whether this field is required (proto2) + 
bool has_default_value; // Whether this field has a default value +}; + +struct ProtobufDecodeContext { + std::vector schema; + std::vector schema_output_types; + std::vector default_ints; + std::vector default_floats; + std::vector default_bools; + std::vector> default_strings; + std::vector> enum_valid_values; + std::vector>> enum_names; + bool fail_on_errors; +}; + +struct ProtobufFieldMetaView { + nested_field_descriptor const& schema; + cudf::data_type const& output_type; + int64_t default_int; + double default_float; + bool default_bool; + std::vector const& default_string; + std::vector const& enum_valid_values; + std::vector> const& enum_names; +}; + +bool is_encoding_compatible(nested_field_descriptor const& field, cudf::data_type const& type); + +void validate_decode_context(ProtobufDecodeContext const& context); + +ProtobufFieldMetaView make_field_meta_view(ProtobufDecodeContext const& context, int schema_idx); + +std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, + ProtobufDecodeContext const& context, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +} // namespace spark_rapids_jni::protobuf diff --git a/src/main/cpp/src/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu similarity index 99% rename from src/main/cpp/src/protobuf_builders.cu rename to src/main/cpp/src/protobuf/protobuf_builders.cu index 9720d783be..756948a206 100644 --- a/src/main/cpp/src/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -14,11 +14,11 @@ * limitations under the License. */ -#include "protobuf_common.cuh" +#include "protobuf/protobuf_host_helpers.hpp" #include -namespace spark_rapids_jni::protobuf_detail { +namespace spark_rapids_jni::protobuf::detail { /** * Helper to build string or bytes column for repeated message child fields. 
@@ -940,8 +940,8 @@ std::unique_ptr build_repeated_struct_column( break; } case cudf::type_id::STRING: { - if (enc == - spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { + if (enc == spark_rapids_jni::protobuf::encoding_value( + spark_rapids_jni::protobuf::proto_encoding::ENUM_STRING)) { if (child_schema_idx < static_cast(enum_valid_values.size()) && child_schema_idx < static_cast(enum_names.size()) && !enum_valid_values[child_schema_idx].empty() && @@ -1138,7 +1138,7 @@ std::unique_ptr build_nested_struct_column( for (int i = 0; i < num_child_fields; i++) { int child_idx = child_field_indices[i]; h_child_field_descs[i].field_number = schema[child_idx].field_number; - h_child_field_descs[i].expected_wire_type = schema[child_idx].wire_type; + h_child_field_descs[i].expected_wire_type = static_cast(schema[child_idx].wire_type); h_child_field_descs[i].is_repeated = schema[child_idx].is_repeated; } @@ -1181,7 +1181,7 @@ std::unique_ptr build_nested_struct_column( for (int ci = 0; ci < num_child_fields; ci++) { int child_schema_idx = child_field_indices[ci]; auto const dt = schema_output_types[child_schema_idx]; - auto const enc = schema[child_schema_idx].encoding; + auto const enc = static_cast(schema[child_schema_idx].encoding); bool has_def = schema[child_schema_idx].has_default_value; bool is_repeated = schema[child_schema_idx].is_repeated; @@ -1251,8 +1251,8 @@ std::unique_ptr build_nested_struct_column( break; } case cudf::type_id::STRING: { - if (enc == - spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { + if (enc == spark_rapids_jni::protobuf::encoding_value( + spark_rapids_jni::protobuf::proto_encoding::ENUM_STRING)) { rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; @@ -1474,7 +1474,7 @@ std::unique_ptr build_repeated_child_list_column( device_nested_field_descriptor rep_desc; rep_desc.field_number = schema[child_schema_idx].field_number; - rep_desc.wire_type = schema[child_schema_idx].wire_type; + rep_desc.wire_type = static_cast(schema[child_schema_idx].wire_type); rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); rep_desc.is_repeated = true; rep_desc.parent_idx = -1; @@ -1574,7 +1574,7 @@ std::unique_ptr build_repeated_child_list_column( elem_type_id == cudf::type_id::UINT64 || elem_type_id == cudf::type_id::FLOAT32 || elem_type_id == cudf::type_id::FLOAT64) { child_values = extract_typed_column(cudf::data_type{elem_type_id}, - schema[child_schema_idx].encoding, + static_cast(schema[child_schema_idx].encoding), message_data, nr_loc, total_rep_count, @@ -1597,7 +1597,7 @@ std::unique_ptr build_repeated_child_list_column( } else if (elem_type_id == cudf::type_id::STRING || elem_type_id == cudf::type_id::LIST) { if (elem_type_id == cudf::type_id::STRING && schema[child_schema_idx].encoding == - spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ENUM_STRING)) { + spark_rapids_jni::protobuf::proto_encoding::ENUM_STRING) { if (child_schema_idx < static_cast(enum_valid_values.size()) && child_schema_idx < static_cast(enum_names.size()) && !enum_valid_values[child_schema_idx].empty() && @@ -1719,4 +1719,4 @@ std::unique_ptr build_repeated_child_list_column( num_parent_rows, std::move(list_offs_col), std::move(child_values), 0, rmm::device_buffer{}); } -} // namespace spark_rapids_jni::protobuf_detail +} // namespace spark_rapids_jni::protobuf::detail diff --git a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh new file mode 100644 index 0000000000..894609b970 --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh @@ -0,0 +1,300 @@ +/* + * Copyright (c) 2026, NVIDIA 
CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "protobuf/protobuf_types.cuh" + +#include + +namespace spark_rapids_jni::protobuf::detail { + +// ============================================================================ +// Device helper functions +// ============================================================================ + +__device__ inline bool read_varint(uint8_t const* cur, + uint8_t const* end, + uint64_t& out, + int& bytes) +{ + out = 0; + bytes = 0; + int shift = 0; + // Protobuf varint uses 7 bits per byte with MSB as continuation flag. + // A 64-bit value requires at most ceil(64/7) = 10 bytes. 
+ while (cur < end && bytes < MAX_VARINT_BYTES) { + uint8_t b = *cur++; + // For the 10th byte (bytes == 9, shift == 63), only the lowest bit is valid + if (bytes == 9 && (b & 0xFE) != 0) { + return false; // Invalid: 10th byte has more than 1 significant bit + } + out |= (static_cast(b & 0x7Fu) << shift); + bytes++; + if ((b & 0x80u) == 0) { return true; } + shift += 7; + } + return false; +} + +__device__ inline void set_error_once(int* error_flag, int error_code) +{ + int expected = 0; + cuda::atomic_ref ref(*error_flag); + ref.compare_exchange_strong(expected, error_code, cuda::memory_order_relaxed); +} + +__global__ void set_error_if_unset_kernel(int* error_flag, int error_code); + +inline void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stream_view stream) +{ + set_error_if_unset_kernel<<<1, 1, 0, stream.value()>>>(error_flag, error_code); + CUDF_CUDA_TRY(cudaPeekAtLastError()); +} + +__device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) +{ + switch (wt) { + case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::VARINT): { + // Need to scan to find the end of varint + int count = 0; + while (cur < end && count < MAX_VARINT_BYTES) { + if ((*cur++ & 0x80u) == 0) { return count + 1; } + count++; + } + return -1; // Invalid varint + } + case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::I64BIT): + // Check if there's enough data for 8 bytes + if (end - cur < 8) return -1; + return 8; + case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::I32BIT): + // Check if there's enough data for 4 bytes + if (end - cur < 4) return -1; + return 4; + case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::LEN): { + uint64_t len; + int n; + if (!read_varint(cur, end, len, n)) return -1; + if (len > static_cast(end - cur - n) || len > static_cast(INT_MAX - n)) + 
return -1; + return n + static_cast(len); + } + case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::SGROUP): { + auto const* start = cur; + int depth = 1; + while (cur < end && depth > 0) { + uint64_t key; + int key_bytes; + if (!read_varint(cur, end, key, key_bytes)) return -1; + cur += key_bytes; + + int inner_wt = static_cast(key & 0x7); + if (inner_wt == + spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::EGROUP)) { + --depth; + if (depth == 0) { return static_cast(cur - start); } + } else if (inner_wt == + spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::SGROUP)) { + if (++depth > 32) return -1; + } else { + int inner_size = -1; + switch (inner_wt) { + case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::VARINT): { + uint64_t dummy; + int vbytes; + if (!read_varint(cur, end, dummy, vbytes)) return -1; + inner_size = vbytes; + break; + } + case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::I64BIT): + inner_size = 8; + break; + case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::LEN): { + uint64_t len; + int len_bytes; + if (!read_varint(cur, end, len, len_bytes)) return -1; + if (len > static_cast(INT_MAX - len_bytes)) return -1; + inner_size = len_bytes + static_cast(len); + break; + } + case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::I32BIT): + inner_size = 4; + break; + default: return -1; + } + if (inner_size < 0 || cur + inner_size > end) return -1; + cur += inner_size; + } + } + return -1; + } + case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::EGROUP): return 0; + default: return -1; + } +} + +__device__ inline bool skip_field(uint8_t const* cur, + uint8_t const* end, + int wt, + uint8_t const*& out_cur) +{ + // A bare 
end-group is only valid while a start-group payload is being parsed recursively inside + // get_wire_type_size(spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::SGROUP)). + // The scan/count kernels should never accept it as a standalone field because Spark CPU treats + // unmatched end-groups as malformed protobuf. + if (wt == spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::EGROUP)) { + return false; + } + + int size = get_wire_type_size(wt, cur, end); + if (size < 0) return false; + // Ensure we don't skip past the end of the buffer + if (cur + size > end) return false; + out_cur = cur + size; + return true; +} + +/** + * Get the data offset and length for a field at current position. + * Returns true on success, false on error. + */ +__device__ inline bool get_field_data_location( + uint8_t const* cur, uint8_t const* end, int wt, int32_t& data_offset, int32_t& data_length) +{ + if (wt == spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::LEN)) { + // For length-delimited, read the length prefix + uint64_t len; + int len_bytes; + if (!read_varint(cur, end, len, len_bytes)) return false; + if (len > static_cast(end - cur - len_bytes) || + len > static_cast(INT_MAX)) { + return false; + } + data_offset = len_bytes; // offset past the length prefix + data_length = static_cast(len); + } else { + // For fixed-size and varint fields + int field_size = get_wire_type_size(wt, cur, end); + if (field_size < 0) return false; + data_offset = 0; + data_length = field_size; + } + return true; +} + +__device__ __host__ inline size_t flat_index(size_t row, size_t width, size_t col) +{ + return row * width + col; +} + +__device__ inline bool checked_add_int32(int32_t lhs, int32_t rhs, int32_t& out) +{ + auto const sum = static_cast(lhs) + rhs; + if (sum < std::numeric_limits::min() || sum > std::numeric_limits::max()) { + return false; + } + out = 
static_cast(sum); + return true; +} + +__device__ inline bool check_message_bounds(int32_t start, + int32_t end_pos, + cudf::size_type total_size, + int* error_flag) +{ + if (start < 0 || end_pos < start || end_pos > total_size) { + set_error_once(error_flag, ERR_BOUNDS); + return false; + } + return true; +} + +struct proto_tag { + int field_number; + int wire_type; +}; + +__device__ inline bool decode_tag(uint8_t const*& cur, + uint8_t const* end, + proto_tag& tag, + int* error_flag) +{ + uint64_t key; + int key_bytes; + if (!read_varint(cur, end, key, key_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + + cur += key_bytes; + uint64_t fn = key >> 3; + if (fn == 0 || fn > static_cast(spark_rapids_jni::protobuf::MAX_FIELD_NUMBER)) { + set_error_once(error_flag, ERR_FIELD_NUMBER); + return false; + } + tag.field_number = static_cast(fn); + tag.wire_type = static_cast(key & 0x7); + return true; +} + +/** + * Load a little-endian value from unaligned memory. + * Reads bytes individually to avoid unaligned-access issues on GPU. + */ +template +__device__ inline T load_le(uint8_t const* p); + +template <> +__device__ inline uint32_t load_le(uint8_t const* p) +{ + return static_cast(p[0]) | (static_cast(p[1]) << 8) | + (static_cast(p[2]) << 16) | (static_cast(p[3]) << 24); +} + +template <> +__device__ inline uint64_t load_le(uint8_t const* p) +{ + uint64_t v = 0; +#pragma unroll + for (int i = 0; i < 8; ++i) { + v |= (static_cast(p[i]) << (8 * i)); + } + return v; +} + +/** + * O(1) lookup of field_number -> field_index using a direct-mapped table. + * Falls back to linear search when the table is empty (field numbers too large). + */ +// Keep this definition in the header so all CUDA translation units can inline it. 
+__device__ __forceinline__ int lookup_field(int field_number, + int const* lookup_table, + int lookup_table_size, + field_descriptor const* field_descs, + int num_fields) +{ + if (lookup_table != nullptr && field_number > 0 && field_number < lookup_table_size) { + return lookup_table[field_number]; + } + for (int f = 0; f < num_fields; f++) { + if (field_descs[f].field_number == field_number) return f; + } + return -1; +} + +} // namespace spark_rapids_jni::protobuf::detail + diff --git a/src/main/cpp/src/protobuf_common.cuh b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp similarity index 52% rename from src/main/cpp/src/protobuf_common.cuh rename to src/main/cpp/src/protobuf/protobuf_host_helpers.hpp index 6178e5e970..e41a9eab99 100644 --- a/src/main/cpp/src/protobuf_common.cuh +++ b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp @@ -16,397 +16,9 @@ #pragma once -#include "protobuf.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace spark_rapids_jni::protobuf_detail { - -// Protobuf varint encoding uses at most 10 bytes to represent a 64-bit value. -constexpr int MAX_VARINT_BYTES = 10; - -// CUDA kernel launch configuration. -constexpr int THREADS_PER_BLOCK = 256; - -// Error codes for kernel error reporting. 
-constexpr int ERR_BOUNDS = 1; -constexpr int ERR_VARINT = 2; -constexpr int ERR_FIELD_NUMBER = 3; -constexpr int ERR_WIRE_TYPE = 4; -constexpr int ERR_OVERFLOW = 5; -constexpr int ERR_FIELD_SIZE = 6; -constexpr int ERR_SKIP = 7; -constexpr int ERR_FIXED_LEN = 8; -constexpr int ERR_REQUIRED = 9; -constexpr int ERR_SCHEMA_TOO_LARGE = 10; -constexpr int ERR_MISSING_ENUM_META = 11; -constexpr int ERR_REPEATED_COUNT_MISMATCH = 12; - -// Maximum supported nesting depth for recursive struct decoding. -constexpr int MAX_NESTED_STRUCT_DECODE_DEPTH = 10; - -// Threshold for using a direct-mapped lookup table for field_number -> field_index. -// Field numbers above this threshold fall back to linear search. -constexpr int FIELD_LOOKUP_TABLE_MAX = 4096; +#include "protobuf/protobuf_kernels.cuh" -/** - * Structure to record field location within a message. - * offset < 0 means field was not found. - */ -struct field_location { - int32_t offset; // Offset of field data within the message (-1 if not found) - int32_t length; // Length of field data in bytes -}; - -/** - * Field descriptor passed to the scanning kernel. - */ -struct field_descriptor { - int field_number; // Protobuf field number - int expected_wire_type; // Expected wire type for this field - bool is_repeated; // Repeated children are scanned via count/scan kernels -}; - -/** - * Information about repeated field occurrences in a row. - */ -struct repeated_field_info { - int32_t count; // Number of occurrences in this row - int32_t total_length; // Total bytes for all occurrences (for varlen fields) -}; - -/** - * Location of a single occurrence of a repeated field. - */ -struct repeated_occurrence { - int32_t row_idx; // Which row this occurrence belongs to - int32_t offset; // Offset within the message - int32_t length; // Length of the field data -}; - -/** - * Per-field descriptor passed to the combined occurrence scan kernel. - * Contains device pointers so the kernel can write to each field's output. 
- */ -struct repeated_field_scan_desc { - int field_number; - int wire_type; - int32_t const* row_offsets; // Pre-computed prefix-sum offsets [num_rows + 1] - repeated_occurrence* occurrences; // Output buffer [total_count] -}; - -/** - * Device-side descriptor for nested schema fields. - */ -struct device_nested_field_descriptor { - int field_number; - int parent_idx; - int depth; - int wire_type; - int output_type_id; - int encoding; - bool is_repeated; - bool is_required; - bool has_default_value; - - device_nested_field_descriptor() = default; - - explicit device_nested_field_descriptor(spark_rapids_jni::nested_field_descriptor const& src) - : field_number(src.field_number), - parent_idx(src.parent_idx), - depth(src.depth), - wire_type(src.wire_type), - output_type_id(static_cast(src.output_type)), - encoding(src.encoding), - is_repeated(src.is_repeated), - is_required(src.is_required), - has_default_value(src.has_default_value) - { - } -}; - -// ============================================================================ -// Device helper functions -// ============================================================================ - -__device__ inline bool read_varint(uint8_t const* cur, - uint8_t const* end, - uint64_t& out, - int& bytes) -{ - out = 0; - bytes = 0; - int shift = 0; - // Protobuf varint uses 7 bits per byte with MSB as continuation flag. - // A 64-bit value requires at most ceil(64/7) = 10 bytes. 
- while (cur < end && bytes < MAX_VARINT_BYTES) { - uint8_t b = *cur++; - // For the 10th byte (bytes == 9, shift == 63), only the lowest bit is valid - if (bytes == 9 && (b & 0xFE) != 0) { - return false; // Invalid: 10th byte has more than 1 significant bit - } - out |= (static_cast(b & 0x7Fu) << shift); - bytes++; - if ((b & 0x80u) == 0) { return true; } - shift += 7; - } - return false; -} - -__device__ inline void set_error_once(int* error_flag, int error_code) -{ - atomicCAS(error_flag, 0, error_code); -} - -__global__ void set_error_if_unset_kernel(int* error_flag, int error_code); - -inline void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stream_view stream) -{ - set_error_if_unset_kernel<<<1, 1, 0, stream.value()>>>(error_flag, error_code); - CUDF_CUDA_TRY(cudaPeekAtLastError()); -} - -__device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) -{ - switch (wt) { - case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT): { - // Need to scan to find the end of varint - int count = 0; - while (cur < end && count < MAX_VARINT_BYTES) { - if ((*cur++ & 0x80u) == 0) { return count + 1; } - count++; - } - return -1; // Invalid varint - } - case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT): - // Check if there's enough data for 8 bytes - if (end - cur < 8) return -1; - return 8; - case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT): - // Check if there's enough data for 4 bytes - if (end - cur < 4) return -1; - return 4; - case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN): { - uint64_t len; - int n; - if (!read_varint(cur, end, len, n)) return -1; - if (len > static_cast(end - cur - n) || len > static_cast(INT_MAX - n)) - return -1; - return n + static_cast(len); - } - case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::SGROUP): { - auto const* start = cur; - int depth = 1; - 
while (cur < end && depth > 0) { - uint64_t key; - int key_bytes; - if (!read_varint(cur, end, key, key_bytes)) return -1; - cur += key_bytes; - - int inner_wt = static_cast(key & 0x7); - if (inner_wt == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::EGROUP)) { - --depth; - if (depth == 0) { return static_cast(cur - start); } - } else if (inner_wt == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::SGROUP)) { - if (++depth > 32) return -1; - } else { - int inner_size = -1; - switch (inner_wt) { - case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT): { - uint64_t dummy; - int vbytes; - if (!read_varint(cur, end, dummy, vbytes)) return -1; - inner_size = vbytes; - break; - } - case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT): - inner_size = 8; - break; - case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN): { - uint64_t len; - int len_bytes; - if (!read_varint(cur, end, len, len_bytes)) return -1; - if (len > static_cast(INT_MAX - len_bytes)) return -1; - inner_size = len_bytes + static_cast(len); - break; - } - case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT): - inner_size = 4; - break; - default: return -1; - } - if (inner_size < 0 || cur + inner_size > end) return -1; - cur += inner_size; - } - } - return -1; - } - case spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::EGROUP): return 0; - default: return -1; - } -} - -__device__ inline bool skip_field(uint8_t const* cur, - uint8_t const* end, - int wt, - uint8_t const*& out_cur) -{ - // A bare end-group is only valid while a start-group payload is being parsed recursively inside - // get_wire_type_size(spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::SGROUP)). - // The scan/count kernels should never accept it as a standalone field because Spark CPU treats - // unmatched end-groups as malformed protobuf. 
- if (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::EGROUP)) { - return false; - } - - int size = get_wire_type_size(wt, cur, end); - if (size < 0) return false; - // Ensure we don't skip past the end of the buffer - if (cur + size > end) return false; - out_cur = cur + size; - return true; -} - -/** - * Get the data offset and length for a field at current position. - * Returns true on success, false on error. - */ -__device__ inline bool get_field_data_location( - uint8_t const* cur, uint8_t const* end, int wt, int32_t& data_offset, int32_t& data_length) -{ - if (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)) { - // For length-delimited, read the length prefix - uint64_t len; - int len_bytes; - if (!read_varint(cur, end, len, len_bytes)) return false; - if (len > static_cast(end - cur - len_bytes) || - len > static_cast(INT_MAX)) { - return false; - } - data_offset = len_bytes; // offset past the length prefix - data_length = static_cast(len); - } else { - // For fixed-size and varint fields - int field_size = get_wire_type_size(wt, cur, end); - if (field_size < 0) return false; - data_offset = 0; - data_length = field_size; - } - return true; -} - -__device__ __host__ inline size_t flat_index(size_t row, size_t width, size_t col) -{ - return row * width + col; -} - -__device__ inline bool checked_add_int32(int32_t lhs, int32_t rhs, int32_t& out) -{ - auto const sum = static_cast(lhs) + rhs; - if (sum < std::numeric_limits::min() || sum > std::numeric_limits::max()) { - return false; - } - out = static_cast(sum); - return true; -} - -__device__ inline bool check_message_bounds(int32_t start, - int32_t end_pos, - cudf::size_type total_size, - int* error_flag) -{ - if (start < 0 || end_pos < start || end_pos > total_size) { - set_error_once(error_flag, ERR_BOUNDS); - return false; - } - return true; -} - -struct proto_tag { - int field_number; - int wire_type; -}; - -__device__ inline bool 
decode_tag(uint8_t const*& cur, - uint8_t const* end, - proto_tag& tag, - int* error_flag) -{ - uint64_t key; - int key_bytes; - if (!read_varint(cur, end, key, key_bytes)) { - set_error_once(error_flag, ERR_VARINT); - return false; - } - - cur += key_bytes; - uint64_t fn = key >> 3; - if (fn == 0 || fn > static_cast(spark_rapids_jni::MAX_FIELD_NUMBER)) { - set_error_once(error_flag, ERR_FIELD_NUMBER); - return false; - } - tag.field_number = static_cast(fn); - tag.wire_type = static_cast(key & 0x7); - return true; -} - -/** - * Load a little-endian value from unaligned memory. - * Reads bytes individually to avoid unaligned-access issues on GPU. - */ -template -__device__ inline T load_le(uint8_t const* p); - -template <> -__device__ inline uint32_t load_le(uint8_t const* p) -{ - return static_cast(p[0]) | (static_cast(p[1]) << 8) | - (static_cast(p[2]) << 16) | (static_cast(p[3]) << 24); -} - -template <> -__device__ inline uint64_t load_le(uint8_t const* p) -{ - uint64_t v = 0; -#pragma unroll - for (int i = 0; i < 8; ++i) { - v |= (static_cast(p[i]) << (8 * i)); - } - return v; -} +namespace spark_rapids_jni::protobuf::detail { // ============================================================================ // Field number lookup table helpers @@ -451,411 +63,6 @@ inline std::vector build_field_lookup_table(field_descriptor const* descs, return table; } -/** - * O(1) lookup of field_number -> field_index using a direct-mapped table. - * Falls back to linear search when the table is empty (field numbers too large). - */ -// Keep this definition in the header so all CUDA translation units can inline it. 
-__device__ __forceinline__ int lookup_field(int field_number, - int const* lookup_table, - int lookup_table_size, - field_descriptor const* field_descs, - int num_fields) -{ - if (lookup_table != nullptr && field_number > 0 && field_number < lookup_table_size) { - return lookup_table[field_number]; - } - for (int f = 0; f < num_fields; f++) { - if (field_descs[f].field_number == field_number) return f; - } - return -1; -} - -// ============================================================================ -// Pass 2: Extract data kernels -// ============================================================================ - -// ============================================================================ -// Data Extraction Location Providers -// ============================================================================ - -struct TopLevelLocationProvider { - cudf::size_type const* offsets; - cudf::size_type base_offset; - field_location const* locations; - int field_idx; - int num_fields; - - __device__ inline field_location get(int thread_idx, int32_t& data_offset) const - { - auto loc = locations[flat_index(static_cast(thread_idx), - static_cast(num_fields), - static_cast(field_idx))]; - if (loc.offset >= 0) { data_offset = offsets[thread_idx] - base_offset + loc.offset; } - return loc; - } -}; - -struct RepeatedLocationProvider { - cudf::size_type const* row_offsets; - cudf::size_type base_offset; - repeated_occurrence const* occurrences; - - __device__ inline field_location get(int thread_idx, int32_t& data_offset) const - { - auto occ = occurrences[thread_idx]; - data_offset = row_offsets[occ.row_idx] - base_offset + occ.offset; - return {occ.offset, occ.length}; - } -}; - -struct NestedLocationProvider { - cudf::size_type const* row_offsets; - cudf::size_type base_offset; - field_location const* parent_locations; - field_location const* child_locations; - int field_idx; - int num_fields; - - __device__ inline field_location get(int thread_idx, int32_t& 
data_offset) const - { - auto ploc = parent_locations[thread_idx]; - auto cloc = child_locations[flat_index(static_cast(thread_idx), - static_cast(num_fields), - static_cast(field_idx))]; - if (ploc.offset >= 0 && cloc.offset >= 0) { - data_offset = row_offsets[thread_idx] - base_offset + ploc.offset + cloc.offset; - } else { - cloc.offset = -1; - } - return cloc; - } -}; - -struct NestedRepeatedLocationProvider { - cudf::size_type const* row_offsets; - cudf::size_type base_offset; - field_location const* parent_locations; - repeated_occurrence const* occurrences; - - __device__ inline field_location get(int thread_idx, int32_t& data_offset) const - { - auto occ = occurrences[thread_idx]; - auto ploc = parent_locations[occ.row_idx]; - if (ploc.offset >= 0) { - data_offset = row_offsets[occ.row_idx] - base_offset + ploc.offset + occ.offset; - return {occ.offset, occ.length}; - } - data_offset = 0; - return {-1, 0}; - } -}; - -struct RepeatedMsgChildLocationProvider { - cudf::size_type const* row_offsets; - cudf::size_type base_offset; - field_location const* msg_locations; - field_location const* child_locations; - int field_idx; - int num_fields; - - __device__ inline field_location get(int thread_idx, int32_t& data_offset) const - { - auto mloc = msg_locations[thread_idx]; - auto cloc = child_locations[flat_index(static_cast(thread_idx), - static_cast(num_fields), - static_cast(field_idx))]; - if (mloc.offset >= 0 && cloc.offset >= 0) { - data_offset = row_offsets[thread_idx] - base_offset + mloc.offset + cloc.offset; - } else { - cloc.offset = -1; - } - return cloc; - } -}; - -template -__global__ void extract_varint_kernel(uint8_t const* message_data, - LocationProvider loc_provider, - int total_items, - OutputType* out, - bool* valid, - int* error_flag, - bool has_default = false, - int64_t default_value = 0) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_items) return; - - int32_t data_offset = 0; - auto loc = 
loc_provider.get(idx, data_offset); - - // For BOOL8 (uint8_t), protobuf spec says any non-zero varint is true. - // A raw static_cast would silently truncate values >= 256 to 0. - auto const write_value = [](OutputType* dst, uint64_t val) { - if constexpr (std::is_same_v) { - *dst = static_cast(val != 0 ? 1 : 0); - } else { - *dst = static_cast(val); - } - }; - - if (loc.offset < 0) { - if (has_default) { - write_value(&out[idx], static_cast(default_value)); - if (valid) valid[idx] = true; - } else { - if (valid) valid[idx] = false; - } - return; - } - - uint8_t const* cur = message_data + data_offset; - uint8_t const* cur_end = cur + loc.length; - - uint64_t v; - int n; - if (!read_varint(cur, cur_end, v, n)) { - set_error_once(error_flag, ERR_VARINT); - if (valid) valid[idx] = false; - return; - } - - if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } - write_value(&out[idx], v); - if (valid) valid[idx] = true; -} - -template -__global__ void extract_fixed_kernel(uint8_t const* message_data, - LocationProvider loc_provider, - int total_items, - OutputType* out, - bool* valid, - int* error_flag, - bool has_default = false, - OutputType default_value = OutputType{}) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_items) return; - - int32_t data_offset = 0; - auto loc = loc_provider.get(idx, data_offset); - - if (loc.offset < 0) { - if (has_default) { - out[idx] = default_value; - if (valid) valid[idx] = true; - } else { - if (valid) valid[idx] = false; - } - return; - } - - uint8_t const* cur = message_data + data_offset; - OutputType value; - - if constexpr (WT == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)) { - if (loc.length < 4) { - set_error_once(error_flag, ERR_FIXED_LEN); - if (valid) valid[idx] = false; - return; - } - uint32_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } else { - if (loc.length < 8) { - set_error_once(error_flag, ERR_FIXED_LEN); - if (valid) 
valid[idx] = false; - return; - } - uint64_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } - - out[idx] = value; - if (valid) valid[idx] = true; -} - -// ============================================================================ -// Batched scalar extraction — one 2D kernel for N fields of the same type -// ============================================================================ - -struct batched_scalar_desc { - int loc_field_idx; // index into the locations array (column within d_locations) - void* output; // pre-allocated output buffer (T*) - bool* valid; // pre-allocated validity buffer - bool has_default; - int64_t default_int; - double default_float; -}; - -template -__global__ void extract_varint_batched_kernel(uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* locations, - int num_loc_fields, - batched_scalar_desc const* descs, - int num_descs, - int num_rows, - int* error_flag) -{ - int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - int fi = static_cast(blockIdx.y); - if (row >= num_rows || fi >= num_descs) return; - - auto const& desc = descs[fi]; - auto loc = locations[row * num_loc_fields + desc.loc_field_idx]; - auto* out = static_cast(desc.output); - - auto const write_value = [](OutputType* dst, uint64_t val) { - if constexpr (std::is_same_v) { - *dst = static_cast(val != 0 ? 
1 : 0); - } else { - *dst = static_cast(val); - } - }; - - if (loc.offset < 0) { - if (desc.has_default) { - write_value(&out[row], static_cast(desc.default_int)); - desc.valid[row] = true; - } else { - desc.valid[row] = false; - } - return; - } - - int32_t data_offset = row_offsets[row] - base_offset + loc.offset; - uint8_t const* cur = message_data + data_offset; - uint8_t const* end = cur + loc.length; - - uint64_t v; - int n; - if (!read_varint(cur, end, v, n)) { - set_error_once(error_flag, ERR_VARINT); - desc.valid[row] = false; - return; - } - if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } - write_value(&out[row], v); - desc.valid[row] = true; -} - -template -__global__ void extract_fixed_batched_kernel(uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* locations, - int num_loc_fields, - batched_scalar_desc const* descs, - int num_descs, - int num_rows, - int* error_flag) -{ - int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - int fi = static_cast(blockIdx.y); - if (row >= num_rows || fi >= num_descs) return; - - auto const& desc = descs[fi]; - auto loc = locations[row * num_loc_fields + desc.loc_field_idx]; - auto* out = static_cast(desc.output); - - if (loc.offset < 0) { - if (desc.has_default) { - if constexpr (std::is_integral_v) { - out[row] = static_cast(desc.default_int); - } else { - out[row] = static_cast(desc.default_float); - } - desc.valid[row] = true; - } else { - desc.valid[row] = false; - } - return; - } - - int32_t data_offset = row_offsets[row] - base_offset + loc.offset; - uint8_t const* cur = message_data + data_offset; - OutputType value; - - if constexpr (WT == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)) { - if (loc.length < 4) { - set_error_once(error_flag, ERR_FIXED_LEN); - desc.valid[row] = false; - return; - } - uint32_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } else { - if (loc.length < 
8) { - set_error_once(error_flag, ERR_FIXED_LEN); - desc.valid[row] = false; - return; - } - uint64_t raw = load_le(cur); - memcpy(&value, &raw, sizeof(value)); - } - out[row] = value; - desc.valid[row] = true; -} - -// ============================================================================ - -template -__global__ void extract_lengths_kernel(LocationProvider loc_provider, - int total_items, - int32_t* out_lengths, - bool has_default = false, - int32_t default_length = 0) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_items) return; - - int32_t data_offset = 0; - auto loc = loc_provider.get(idx, data_offset); - - if (loc.offset >= 0) { - out_lengths[idx] = loc.length; - } else if (has_default) { - out_lengths[idx] = default_length; - } else { - out_lengths[idx] = 0; - } -} -template -__global__ void copy_varlen_data_kernel(uint8_t const* message_data, - LocationProvider loc_provider, - int total_items, - cudf::size_type const* output_offsets, - char* output_chars, - int* error_flag, - bool has_default = false, - uint8_t const* default_chars = nullptr, - int default_len = 0) -{ - auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_items) return; - - int32_t data_offset = 0; - auto loc = loc_provider.get(idx, data_offset); - - auto out_start = output_offsets[idx]; - - if (loc.offset < 0) { - if (has_default && default_len > 0) { - memcpy(output_chars + out_start, default_chars, default_len); - } - return; - } - - uint8_t const* src = message_data + data_offset; - memcpy(output_chars + out_start, src, loc.length); -} - template inline std::pair make_null_mask_from_valid( rmm::device_uvector const& valid, @@ -901,7 +108,7 @@ inline void extract_integer_into_buffers(uint8_t const* message_data, rmm::cuda_stream_view stream) { if (enable_zigzag && - encoding == spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ZIGZAG)) { + encoding == 
spark_rapids_jni::protobuf::encoding_value(spark_rapids_jni::protobuf::proto_encoding::ZIGZAG)) { extract_varint_kernel <<>>(message_data, loc_provider, @@ -912,11 +119,11 @@ inline void extract_integer_into_buffers(uint8_t const* message_data, has_default, default_value); } else if (encoding == - spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED)) { + spark_rapids_jni::protobuf::encoding_value(spark_rapids_jni::protobuf::proto_encoding::FIXED)) { if constexpr (sizeof(T) == 4) { extract_fixed_kernel <<>>(message_data, loc_provider, @@ -929,8 +136,8 @@ inline void extract_integer_into_buffers(uint8_t const* message_data, } else { static_assert(sizeof(T) == 8, "extract_integer_into_buffers only supports 32/64-bit"); extract_fixed_kernel <<>>(message_data, loc_provider, @@ -1068,143 +275,6 @@ std::unique_ptr make_empty_struct_column_with_schema( return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer{}, stream, mr); } -// ============================================================================ -// Forward declarations of non-template __global__ kernels -// ============================================================================ - -__global__ void scan_all_fields_kernel(cudf::column_device_view const d_in, - field_descriptor const* field_descs, - int num_fields, - int const* field_lookup, - int field_lookup_size, - field_location* locations, - int* error_flag, - bool* row_has_invalid_data); - -__global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in, - device_nested_field_descriptor const* schema, - int num_fields, - int depth_level, - repeated_field_info* repeated_info, - int num_repeated_fields, - int const* repeated_field_indices, - field_location* nested_locations, - int num_nested_fields, - int const* nested_field_indices, - int* error_flag, - int const* fn_to_rep_idx = nullptr, - int fn_to_rep_size = 0, - int const* fn_to_nested_idx = nullptr, - int fn_to_nested_size = 0); - -__global__ 
void scan_all_repeated_occurrences_kernel(cudf::column_device_view const d_in, - repeated_field_scan_desc const* scan_descs, - int num_scan_fields, - int* error_flag, - int const* fn_to_desc_idx = nullptr, - int fn_to_desc_size = 0); - -__global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, - field_location const* parent_locations, - int num_parent_rows, - field_descriptor const* field_descs, - int num_fields, - field_location* output_locations, - int* error_flag); - -__global__ void scan_repeated_message_children_kernel(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* msg_row_offsets, - field_location const* msg_locs, - int num_occurrences, - field_descriptor const* child_descs, - int num_child_fields, - field_location* child_locs, - int* error_flag, - int const* child_lookup = nullptr, - int child_lookup_size = 0); - -__global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_rows, - device_nested_field_descriptor const* schema, - int num_fields, - repeated_field_info* repeated_info, - int num_repeated, - int const* repeated_indices, - int* error_flag); - -__global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_rows, - device_nested_field_descriptor const* schema, - int32_t const* occ_prefix_sums, - int const* repeated_indices, - repeated_occurrence* occurrences, - int* error_flag); - -__global__ void compute_nested_struct_locations_kernel(field_location const* child_locs, - field_location const* msg_locs, - cudf::size_type const* 
msg_row_offsets, - int child_idx, - int num_child_fields, - field_location* nested_locs, - cudf::size_type* nested_row_offsets, - int total_count, - int* error_flag); - -__global__ void compute_grandchild_parent_locations_kernel(field_location const* parent_locs, - field_location const* child_locs, - int child_idx, - int num_child_fields, - field_location* gc_parent_abs, - int num_rows, - int* error_flag); - -__global__ void compute_virtual_parents_for_nested_repeated_kernel( - repeated_occurrence const* occurrences, - cudf::size_type const* row_list_offsets, - field_location const* parent_locations, - cudf::size_type* virtual_row_offsets, - field_location* virtual_parent_locs, - int total_count, - int* error_flag); - -__global__ void compute_msg_locations_from_occurrences_kernel( - repeated_occurrence const* occurrences, - cudf::size_type const* list_offsets, - cudf::size_type base_offset, - field_location* msg_locs, - cudf::size_type* msg_row_offsets, - int total_count, - int* error_flag); - -__global__ void extract_strided_locations_kernel(field_location const* nested_locations, - int field_idx, - int num_fields, - field_location* parent_locs, - int num_rows); - -__global__ void check_required_fields_kernel(field_location const* locations, - uint8_t const* is_required, - int num_fields, - int num_rows, - cudf::bitmask_type const* input_null_mask, - cudf::size_type input_offset, - field_location const* parent_locs, - bool* row_force_null, - int32_t const* top_row_indices, - int* error_flag); - inline void maybe_check_required_fields(field_location const* locations, std::vector const& field_indices, std::vector const& schema, @@ -1681,8 +751,8 @@ inline std::unique_ptr extract_typed_column( num_items, [&](float* out_ptr, bool* valid_ptr) { extract_fixed_kernel <<>>(message_data, loc_provider, @@ -1703,8 +773,8 @@ inline std::unique_ptr extract_typed_column( num_items, [&](double* out_ptr, bool* valid_ptr) { extract_fixed_kernel <<>>(message_data, loc_provider, @@ 
-1782,28 +852,28 @@ inline std::unique_ptr build_repeated_scalar_column( int encoding = field_desc.encoding; bool zigzag = - (encoding == spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::ZIGZAG)); + (encoding == spark_rapids_jni::protobuf::encoding_value(spark_rapids_jni::protobuf::proto_encoding::ZIGZAG)); // For float/double types, always use fixed kernel (they use wire type 32BIT/64BIT) // For integer types, use fixed kernel only if encoding is - // spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED) + // spark_rapids_jni::protobuf::encoding_value(spark_rapids_jni::protobuf::proto_encoding::FIXED) constexpr bool is_floating_point = std::is_same_v || std::is_same_v; bool use_fixed_kernel = is_floating_point || - (encoding == spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED)); + (encoding == spark_rapids_jni::protobuf::encoding_value(spark_rapids_jni::protobuf::proto_encoding::FIXED)); RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; if (use_fixed_kernel) { if constexpr (sizeof(T) == 4) { extract_fixed_kernel<<>>( message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } else { extract_fixed_kernel<<>>( message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } @@ -1845,4 +915,5 @@ inline std::unique_ptr build_repeated_scalar_column( num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); } -} // namespace spark_rapids_jni::protobuf_detail +} // namespace spark_rapids_jni::protobuf::detail + diff --git a/src/main/cpp/src/protobuf_kernels.cu b/src/main/cpp/src/protobuf/protobuf_kernels.cu similarity index 94% rename from src/main/cpp/src/protobuf_kernels.cu rename to src/main/cpp/src/protobuf/protobuf_kernels.cu index 118c918ebd..ea7d42b6d9 100644 --- a/src/main/cpp/src/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cu @@ -14,9 +14,9 @@ * limitations under the 
License. */ -#include "protobuf_common.cuh" +#include "protobuf/protobuf_device_helpers.cuh" -namespace spark_rapids_jni::protobuf_detail { +namespace spark_rapids_jni::protobuf::detail { // ============================================================================ // Pass 1: Scan all fields kernel - records (offset, length) for each field @@ -99,7 +99,8 @@ __global__ void scan_all_fields_kernel( // Record the location (relative to message start) int data_offset = static_cast(cur - bytes - start); - if (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)) { + if (wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN)) { // For length-delimited, record offset after length prefix and the data length uint64_t len; int len_bytes; @@ -165,9 +166,10 @@ __device__ bool count_repeated_element(uint8_t const* cur, repeated_field_info& info, int* error_flag) { - bool is_packed = - (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN) && - expected_wt != spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)); + bool is_packed = (wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN) && + expected_wt != spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN)); if (!is_packed && wt != expected_wt) { set_error_once(error_flag, ERR_WIRE_TYPE); @@ -189,8 +191,8 @@ __device__ bool count_repeated_element(uint8_t const* cur, uint8_t const* packed_end = packed_start + packed_len; int count = 0; - if (expected_wt == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT)) { + if (expected_wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::VARINT)) { uint8_t const* p = packed_start; while (p < packed_end) { uint64_t dummy; @@ -202,15 +204,15 @@ __device__ bool count_repeated_element(uint8_t const* cur, p += 
vbytes; count++; } - } else if (expected_wt == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)) { + } else if (expected_wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { if ((packed_len % 4) != 0) { set_error_once(error_flag, ERR_FIXED_LEN); return false; } count = static_cast(packed_len / 4); - } else if (expected_wt == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT)) { + } else if (expected_wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I64BIT)) { if ((packed_len % 8) != 0) { set_error_once(error_flag, ERR_FIXED_LEN); return false; @@ -249,9 +251,10 @@ __device__ bool scan_repeated_element(uint8_t const* cur, int write_end, int* error_flag) { - bool is_packed = - (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN) && - expected_wt != spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)); + bool is_packed = (wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN) && + expected_wt != spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN)); if (!is_packed && wt != expected_wt) { set_error_once(error_flag, ERR_WIRE_TYPE); @@ -272,8 +275,8 @@ __device__ bool scan_repeated_element(uint8_t const* cur, } uint8_t const* packed_end = packed_start + packed_len; - if (expected_wt == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT)) { + if (expected_wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::VARINT)) { uint8_t const* p = packed_start; while (p < packed_end) { int32_t elem_offset = static_cast(p - msg_base); @@ -291,8 +294,8 @@ __device__ bool scan_repeated_element(uint8_t const* cur, write_idx++; p += vbytes; } - } else if (expected_wt == - 
spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)) { + } else if (expected_wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { if ((packed_len % 4) != 0) { set_error_once(error_flag, ERR_FIXED_LEN); return false; @@ -305,8 +308,8 @@ __device__ bool scan_repeated_element(uint8_t const* cur, occurrences[write_idx] = {row, static_cast(packed_start - msg_base + i), 4}; write_idx++; } - } else if (expected_wt == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT)) { + } else if (expected_wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I64BIT)) { if ((packed_len % 8) != 0) { set_error_once(error_flag, ERR_FIXED_LEN); return false; @@ -438,7 +441,8 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in // Check nested message fields at this depth auto handle_nested = [&](int i) { - if (wt != spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)) { + if (wt != spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN)) { set_error_once(error_flag, ERR_WIRE_TYPE); return false; } @@ -541,10 +545,11 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co int wt = tag.wire_type; auto try_scan = [&](int f) -> bool { - int target_wt = scan_descs[f].wire_type; - bool is_packed = - (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN) && - target_wt != spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)); + int target_wt = scan_descs[f].wire_type; + bool is_packed = (wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN) && + target_wt != spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN)); if (is_packed || wt == target_wt) { return scan_repeated_element(cur, msg_end, 
@@ -657,7 +662,8 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, int data_offset = static_cast(cur - nested_start); - if (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)) { + if (wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN)) { uint64_t len; int len_bytes; if (!read_varint(cur, nested_end, len, len_bytes)) { @@ -760,7 +766,8 @@ __global__ void scan_repeated_message_children_kernel( } else { int data_offset = static_cast(cur - msg_start); - if (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)) { + if (wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN)) { uint64_t len; int len_bytes; if (!read_varint(cur, msg_end, len, len_bytes)) { @@ -784,7 +791,8 @@ __global__ void scan_repeated_message_children_kernel( } else { // For varint/fixed types, store offset and estimated length int32_t data_length = 0; - if (wt == spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT)) { + if (wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::VARINT)) { uint64_t dummy; int vbytes; if (!read_varint(cur, msg_end, dummy, vbytes)) { @@ -792,15 +800,15 @@ __global__ void scan_repeated_message_children_kernel( return; } data_length = vbytes; - } else if (wt == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)) { + } else if (wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { if (msg_end - cur < 4) { set_error_once(error_flag, ERR_FIXED_LEN); return; } data_length = 4; - } else if (wt == - spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT)) { + } else if (wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I64BIT)) { if (msg_end - cur < 8) { 
set_error_once(error_flag, ERR_FIXED_LEN); return; @@ -1318,4 +1326,4 @@ __global__ void copy_enum_string_chars_kernel( } } -} // namespace spark_rapids_jni::protobuf_detail +} // namespace spark_rapids_jni::protobuf::detail diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cuh b/src/main/cpp/src/protobuf/protobuf_kernels.cuh new file mode 100644 index 0000000000..04aee0042c --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cuh @@ -0,0 +1,571 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "protobuf/protobuf_device_helpers.cuh" + +namespace spark_rapids_jni::protobuf::detail { + +// ============================================================================ +// Pass 2: Extract data kernels +// ============================================================================ + +// ============================================================================ +// Data Extraction Location Providers +// ============================================================================ + +struct TopLevelLocationProvider { + cudf::size_type const* offsets; + cudf::size_type base_offset; + field_location const* locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto loc = locations[flat_index(static_cast(thread_idx), + static_cast(num_fields), + static_cast(field_idx))]; + if (loc.offset >= 0) { data_offset = offsets[thread_idx] - base_offset + loc.offset; } + return loc; + } +}; + +struct RepeatedLocationProvider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + repeated_occurrence const* occurrences; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto occ = occurrences[thread_idx]; + data_offset = row_offsets[occ.row_idx] - base_offset + occ.offset; + return {occ.offset, occ.length}; + } +}; + +struct NestedLocationProvider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* parent_locations; + field_location const* child_locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto ploc = parent_locations[thread_idx]; + auto cloc = child_locations[flat_index(static_cast(thread_idx), + static_cast(num_fields), + static_cast(field_idx))]; + if (ploc.offset >= 0 && cloc.offset >= 0) { + data_offset = row_offsets[thread_idx] - base_offset + ploc.offset + cloc.offset; 
+ } else { + cloc.offset = -1; + } + return cloc; + } +}; + +struct NestedRepeatedLocationProvider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* parent_locations; + repeated_occurrence const* occurrences; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto occ = occurrences[thread_idx]; + auto ploc = parent_locations[occ.row_idx]; + if (ploc.offset >= 0) { + data_offset = row_offsets[occ.row_idx] - base_offset + ploc.offset + occ.offset; + return {occ.offset, occ.length}; + } + data_offset = 0; + return {-1, 0}; + } +}; + +struct RepeatedMsgChildLocationProvider { + cudf::size_type const* row_offsets; + cudf::size_type base_offset; + field_location const* msg_locations; + field_location const* child_locations; + int field_idx; + int num_fields; + + __device__ inline field_location get(int thread_idx, int32_t& data_offset) const + { + auto mloc = msg_locations[thread_idx]; + auto cloc = child_locations[flat_index(static_cast(thread_idx), + static_cast(num_fields), + static_cast(field_idx))]; + if (mloc.offset >= 0 && cloc.offset >= 0) { + data_offset = row_offsets[thread_idx] - base_offset + mloc.offset + cloc.offset; + } else { + cloc.offset = -1; + } + return cloc; + } +}; + +template +__global__ void extract_varint_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + OutputType* out, + bool* valid, + int* error_flag, + bool has_default = false, + int64_t default_value = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + // For BOOL8 (uint8_t), protobuf spec says any non-zero varint is true. + // A raw static_cast would silently truncate values >= 256 to 0. + auto const write_value = [](OutputType* dst, uint64_t val) { + if constexpr (std::is_same_v) { + *dst = static_cast(val != 0 ? 
1 : 0); + } else { + *dst = static_cast(val); + } + }; + + if (loc.offset < 0) { + if (has_default) { + write_value(&out[idx], static_cast(default_value)); + if (valid) valid[idx] = true; + } else { + if (valid) valid[idx] = false; + } + return; + } + + uint8_t const* cur = message_data + data_offset; + uint8_t const* cur_end = cur + loc.length; + + uint64_t v; + int n; + if (!read_varint(cur, cur_end, v, n)) { + set_error_once(error_flag, ERR_VARINT); + if (valid) valid[idx] = false; + return; + } + + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } + write_value(&out[idx], v); + if (valid) valid[idx] = true; +} + +template +__global__ void extract_fixed_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + OutputType* out, + bool* valid, + int* error_flag, + bool has_default = false, + OutputType default_value = OutputType{}) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + if (loc.offset < 0) { + if (has_default) { + out[idx] = default_value; + if (valid) valid[idx] = true; + } else { + if (valid) valid[idx] = false; + } + return; + } + + uint8_t const* cur = message_data + data_offset; + OutputType value; + + if constexpr (WT == + spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { + if (loc.length < 4) { + set_error_once(error_flag, ERR_FIXED_LEN); + if (valid) valid[idx] = false; + return; + } + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { + if (loc.length < 8) { + set_error_once(error_flag, ERR_FIXED_LEN); + if (valid) valid[idx] = false; + return; + } + uint64_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } + + out[idx] = value; + if (valid) valid[idx] = true; +} + +// ============================================================================ +// Batched scalar extraction — one 2D 
kernel for N fields of the same type +// ============================================================================ + +struct batched_scalar_desc { + int loc_field_idx; // index into the locations array (column within d_locations) + void* output; // pre-allocated output buffer (T*) + bool* valid; // pre-allocated validity buffer + bool has_default; + int64_t default_int; + double default_float; +}; + +template +__global__ void extract_varint_batched_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* locations, + int num_loc_fields, + batched_scalar_desc const* descs, + int num_descs, + int num_rows, + int* error_flag) +{ + int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + int fi = static_cast(blockIdx.y); + if (row >= num_rows || fi >= num_descs) return; + + auto const& desc = descs[fi]; + auto loc = locations[row * num_loc_fields + desc.loc_field_idx]; + auto* out = static_cast(desc.output); + + auto const write_value = [](OutputType* dst, uint64_t val) { + if constexpr (std::is_same_v) { + *dst = static_cast(val != 0 ? 
1 : 0); + } else { + *dst = static_cast(val); + } + }; + + if (loc.offset < 0) { + if (desc.has_default) { + write_value(&out[row], static_cast(desc.default_int)); + desc.valid[row] = true; + } else { + desc.valid[row] = false; + } + return; + } + + int32_t data_offset = row_offsets[row] - base_offset + loc.offset; + uint8_t const* cur = message_data + data_offset; + uint8_t const* end = cur + loc.length; + + uint64_t v; + int n; + if (!read_varint(cur, end, v, n)) { + set_error_once(error_flag, ERR_VARINT); + desc.valid[row] = false; + return; + } + if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } + write_value(&out[row], v); + desc.valid[row] = true; +} + +template +__global__ void extract_fixed_batched_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* locations, + int num_loc_fields, + batched_scalar_desc const* descs, + int num_descs, + int num_rows, + int* error_flag) +{ + int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + int fi = static_cast(blockIdx.y); + if (row >= num_rows || fi >= num_descs) return; + + auto const& desc = descs[fi]; + auto loc = locations[row * num_loc_fields + desc.loc_field_idx]; + auto* out = static_cast(desc.output); + + if (loc.offset < 0) { + if (desc.has_default) { + if constexpr (std::is_integral_v) { + out[row] = static_cast(desc.default_int); + } else { + out[row] = static_cast(desc.default_float); + } + desc.valid[row] = true; + } else { + desc.valid[row] = false; + } + return; + } + + int32_t data_offset = row_offsets[row] - base_offset + loc.offset; + uint8_t const* cur = message_data + data_offset; + OutputType value; + + if constexpr (WT == + spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { + if (loc.length < 4) { + set_error_once(error_flag, ERR_FIXED_LEN); + desc.valid[row] = false; + return; + } + uint32_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } else { 
+ if (loc.length < 8) { + set_error_once(error_flag, ERR_FIXED_LEN); + desc.valid[row] = false; + return; + } + uint64_t raw = load_le(cur); + memcpy(&value, &raw, sizeof(value)); + } + out[row] = value; + desc.valid[row] = true; +} + +// ============================================================================ + +template +__global__ void extract_lengths_kernel(LocationProvider loc_provider, + int total_items, + int32_t* out_lengths, + bool has_default = false, + int32_t default_length = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + if (loc.offset >= 0) { + out_lengths[idx] = loc.length; + } else if (has_default) { + out_lengths[idx] = default_length; + } else { + out_lengths[idx] = 0; + } +} +template +__global__ void copy_varlen_data_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + cudf::size_type const* output_offsets, + char* output_chars, + int* error_flag, + bool has_default = false, + uint8_t const* default_chars = nullptr, + int default_len = 0) +{ + auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (idx >= total_items) return; + + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + + auto out_start = output_offsets[idx]; + + if (loc.offset < 0) { + if (has_default && default_len > 0) { + memcpy(output_chars + out_start, default_chars, default_len); + } + return; + } + + uint8_t const* src = message_data + data_offset; + memcpy(output_chars + out_start, src, loc.length); +} + +// ============================================================================ +// Forward declarations of non-template __global__ kernels +// ============================================================================ + +__global__ void scan_all_fields_kernel(cudf::column_device_view const d_in, + field_descriptor const* field_descs, + int num_fields, + int 
const* field_lookup, + int field_lookup_size, + field_location* locations, + int* error_flag, + bool* row_has_invalid_data); + +__global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in, + device_nested_field_descriptor const* schema, + int num_fields, + int depth_level, + repeated_field_info* repeated_info, + int num_repeated_fields, + int const* repeated_field_indices, + field_location* nested_locations, + int num_nested_fields, + int const* nested_field_indices, + int* error_flag, + int const* fn_to_rep_idx = nullptr, + int fn_to_rep_size = 0, + int const* fn_to_nested_idx = nullptr, + int fn_to_nested_size = 0); + +__global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view const d_in, + repeated_field_scan_desc const* scan_descs, + int num_scan_fields, + int* error_flag, + int const* fn_to_desc_idx = nullptr, + int fn_to_desc_size = 0); + +__global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + int num_parent_rows, + field_descriptor const* field_descs, + int num_fields, + field_location* output_locations, + int* error_flag); + +__global__ void scan_repeated_message_children_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* msg_row_offsets, + field_location const* msg_locs, + int num_occurrences, + field_descriptor const* child_descs, + int num_child_fields, + field_location* child_locs, + int* error_flag, + int const* child_lookup = nullptr, + int child_lookup_size = 0); + +__global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + 
repeated_field_info* repeated_info, + int num_repeated, + int const* repeated_indices, + int* error_flag); + +__global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int32_t const* occ_prefix_sums, + int const* repeated_indices, + repeated_occurrence* occurrences, + int* error_flag); + +__global__ void compute_nested_struct_locations_kernel(field_location const* child_locs, + field_location const* msg_locs, + cudf::size_type const* msg_row_offsets, + int child_idx, + int num_child_fields, + field_location* nested_locs, + cudf::size_type* nested_row_offsets, + int total_count, + int* error_flag); + +__global__ void compute_grandchild_parent_locations_kernel(field_location const* parent_locs, + field_location const* child_locs, + int child_idx, + int num_child_fields, + field_location* gc_parent_abs, + int num_rows, + int* error_flag); + +__global__ void compute_virtual_parents_for_nested_repeated_kernel( + repeated_occurrence const* occurrences, + cudf::size_type const* row_list_offsets, + field_location const* parent_locations, + cudf::size_type* virtual_row_offsets, + field_location* virtual_parent_locs, + int total_count, + int* error_flag); + +__global__ void compute_msg_locations_from_occurrences_kernel( + repeated_occurrence const* occurrences, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + field_location* msg_locs, + cudf::size_type* msg_row_offsets, + int total_count, + int* error_flag); + +__global__ void extract_strided_locations_kernel(field_location const* nested_locations, + int field_idx, + int num_fields, + field_location* parent_locs, + int num_rows); + +__global__ void check_required_fields_kernel(field_location const* locations, + uint8_t const* is_required, + int num_fields, + int num_rows, + 
cudf::bitmask_type const* input_null_mask, + cudf::size_type input_offset, + field_location const* parent_locs, + bool* row_force_null, + int32_t const* top_row_indices, + int* error_flag); + +__global__ void validate_enum_values_kernel(int32_t const* values, + bool* valid, + bool* row_has_invalid_enum, + int32_t const* valid_enum_values, + int num_valid_values, + int num_rows); + +__global__ void compute_enum_string_lengths_kernel(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + int num_valid_values, + int32_t* lengths, + int num_rows); + +__global__ void copy_enum_string_chars_kernel(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + uint8_t const* enum_name_chars, + int num_valid_values, + int32_t const* output_offsets, + char* out_chars, + int num_rows); + +} // namespace spark_rapids_jni::protobuf::detail + diff --git a/src/main/cpp/src/protobuf/protobuf_types.cuh b/src/main/cpp/src/protobuf/protobuf_types.cuh new file mode 100644 index 0000000000..a6fb3114d3 --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf_types.cuh @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "protobuf/protobuf.hpp" + +#include "protobuf/protobuf.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace spark_rapids_jni::protobuf::detail { + +// Protobuf varint encoding uses at most 10 bytes to represent a 64-bit value. +constexpr int MAX_VARINT_BYTES = 10; + +// CUDA kernel launch configuration. +constexpr int THREADS_PER_BLOCK = 256; + +// Error codes for kernel error reporting. +constexpr int ERR_BOUNDS = 1; +constexpr int ERR_VARINT = 2; +constexpr int ERR_FIELD_NUMBER = 3; +constexpr int ERR_WIRE_TYPE = 4; +constexpr int ERR_OVERFLOW = 5; +constexpr int ERR_FIELD_SIZE = 6; +constexpr int ERR_SKIP = 7; +constexpr int ERR_FIXED_LEN = 8; +constexpr int ERR_REQUIRED = 9; +constexpr int ERR_SCHEMA_TOO_LARGE = 10; +constexpr int ERR_MISSING_ENUM_META = 11; +constexpr int ERR_REPEATED_COUNT_MISMATCH = 12; + +// Maximum supported nesting depth for recursive struct decoding. +constexpr int MAX_NESTED_STRUCT_DECODE_DEPTH = 10; + +// Threshold for using a direct-mapped lookup table for field_number -> field_index. +// Field numbers above this threshold fall back to linear search. +constexpr int FIELD_LOOKUP_TABLE_MAX = 4096; + +/** + * Structure to record field location within a message. + * offset < 0 means field was not found. + */ +struct field_location { + int32_t offset; // Offset of field data within the message (-1 if not found) + int32_t length; // Length of field data in bytes +}; + +/** + * Field descriptor passed to the scanning kernel. 
+ */ +struct field_descriptor { + int field_number; // Protobuf field number + int expected_wire_type; // Expected wire type for this field + bool is_repeated; // Repeated children are scanned via count/scan kernels +}; + +/** + * Information about repeated field occurrences in a row. + */ +struct repeated_field_info { + int32_t count; // Number of occurrences in this row + int32_t total_length; // Total bytes for all occurrences (for varlen fields) +}; + +/** + * Location of a single occurrence of a repeated field. + */ +struct repeated_occurrence { + int32_t row_idx; // Which row this occurrence belongs to + int32_t offset; // Offset within the message + int32_t length; // Length of the field data +}; + +/** + * Per-field descriptor passed to the combined occurrence scan kernel. + * Contains device pointers so the kernel can write to each field's output. + */ +struct repeated_field_scan_desc { + int field_number; + int wire_type; + int32_t const* row_offsets; // Pre-computed prefix-sum offsets [num_rows + 1] + repeated_occurrence* occurrences; // Output buffer [total_count] +}; + +/** + * Device-side descriptor for nested schema fields. 
+ */ +struct device_nested_field_descriptor { + int field_number; + int parent_idx; + int depth; + int wire_type; + int output_type_id; + int encoding; + bool is_repeated; + bool is_required; + bool has_default_value; + + device_nested_field_descriptor() = default; + + explicit device_nested_field_descriptor(spark_rapids_jni::protobuf::nested_field_descriptor const& src) + : field_number(src.field_number), + parent_idx(src.parent_idx), + depth(src.depth), + wire_type(static_cast(src.wire_type)), + output_type_id(static_cast(src.output_type)), + encoding(static_cast(src.encoding)), + is_repeated(src.is_repeated), + is_required(src.is_required), + has_default_value(src.has_default_value) + { + } +}; + +} // namespace spark_rapids_jni::protobuf::detail + From 6bcac2f40abaef44a621e118dae6fe66cd99ed0a Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 19 Mar 2026 10:16:21 +0800 Subject: [PATCH 094/107] style Signed-off-by: Haoyang Li --- .../src/protobuf/protobuf_device_helpers.cuh | 46 ++++++++++++------- .../src/protobuf/protobuf_host_helpers.hpp | 17 ++++--- .../cpp/src/protobuf/protobuf_kernels.cuh | 9 ++-- src/main/cpp/src/protobuf/protobuf_types.cuh | 6 +-- 4 files changed, 43 insertions(+), 35 deletions(-) diff --git a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh index 894609b970..09fc15889c 100644 --- a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh +++ b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh @@ -68,7 +68,8 @@ inline void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stre __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) { switch (wt) { - case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::VARINT): { + case spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::VARINT): { // Need to scan to find the end of varint int count = 0; while (cur < 
end && count < MAX_VARINT_BYTES) { @@ -77,15 +78,18 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con } return -1; // Invalid varint } - case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::I64BIT): + case spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I64BIT): // Check if there's enough data for 8 bytes if (end - cur < 8) return -1; return 8; - case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::I32BIT): + case spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I32BIT): // Check if there's enough data for 4 bytes if (end - cur < 4) return -1; return 4; - case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::LEN): { + case spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN): { uint64_t len; int n; if (!read_varint(cur, end, len, n)) return -1; @@ -93,7 +97,8 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con return -1; return n + static_cast(len); } - case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::SGROUP): { + case spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::SGROUP): { auto const* start = cur; int depth = 1; while (cur < end && depth > 0) { @@ -103,27 +108,30 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con cur += key_bytes; int inner_wt = static_cast(key & 0x7); - if (inner_wt == - spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::EGROUP)) { + if (inner_wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::EGROUP)) { --depth; if (depth == 0) { return static_cast(cur - start); } - } else if (inner_wt == - 
spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::SGROUP)) { + } else if (inner_wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::SGROUP)) { if (++depth > 32) return -1; } else { int inner_size = -1; switch (inner_wt) { - case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::VARINT): { + case spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::VARINT): { uint64_t dummy; int vbytes; if (!read_varint(cur, end, dummy, vbytes)) return -1; inner_size = vbytes; break; } - case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::I64BIT): + case spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I64BIT): inner_size = 8; break; - case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::LEN): { + case spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN): { uint64_t len; int len_bytes; if (!read_varint(cur, end, len, len_bytes)) return -1; @@ -131,7 +139,8 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con inner_size = len_bytes + static_cast(len); break; } - case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::I32BIT): + case spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I32BIT): inner_size = 4; break; default: return -1; @@ -142,7 +151,9 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con } return -1; } - case spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::EGROUP): return 0; + case spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::EGROUP): + return 0; default: return -1; } } @@ -156,7 +167,8 @@ __device__ inline bool 
skip_field(uint8_t const* cur, // get_wire_type_size(spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::SGROUP)). // The scan/count kernels should never accept it as a standalone field because Spark CPU treats // unmatched end-groups as malformed protobuf. - if (wt == spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::EGROUP)) { + if (wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::EGROUP)) { return false; } @@ -175,7 +187,8 @@ __device__ inline bool skip_field(uint8_t const* cur, __device__ inline bool get_field_data_location( uint8_t const* cur, uint8_t const* end, int wt, int32_t& data_offset, int32_t& data_length) { - if (wt == spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::LEN)) { + if (wt == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::LEN)) { // For length-delimited, read the length prefix uint64_t len; int len_bytes; @@ -297,4 +310,3 @@ __device__ __forceinline__ int lookup_field(int field_number, } } // namespace spark_rapids_jni::protobuf::detail - diff --git a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp index e41a9eab99..2f426e0bf1 100644 --- a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp +++ b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp @@ -107,8 +107,8 @@ inline void extract_integer_into_buffers(uint8_t const* message_data, int* error_ptr, rmm::cuda_stream_view stream) { - if (enable_zigzag && - encoding == spark_rapids_jni::protobuf::encoding_value(spark_rapids_jni::protobuf::proto_encoding::ZIGZAG)) { + if (enable_zigzag && encoding == spark_rapids_jni::protobuf::encoding_value( + spark_rapids_jni::protobuf::proto_encoding::ZIGZAG)) { extract_varint_kernel <<>>(message_data, loc_provider, @@ -118,8 +118,8 @@ inline void extract_integer_into_buffers(uint8_t const* 
message_data, error_ptr, has_default, default_value); - } else if (encoding == - spark_rapids_jni::protobuf::encoding_value(spark_rapids_jni::protobuf::proto_encoding::FIXED)) { + } else if (encoding == spark_rapids_jni::protobuf::encoding_value( + spark_rapids_jni::protobuf::proto_encoding::FIXED)) { if constexpr (sizeof(T) == 4) { extract_fixed_kernel build_repeated_scalar_column( auto const blocks = static_cast((total_count + threads - 1u) / threads); int encoding = field_desc.encoding; - bool zigzag = - (encoding == spark_rapids_jni::protobuf::encoding_value(spark_rapids_jni::protobuf::proto_encoding::ZIGZAG)); + bool zigzag = (encoding == spark_rapids_jni::protobuf::encoding_value( + spark_rapids_jni::protobuf::proto_encoding::ZIGZAG)); // For float/double types, always use fixed kernel (they use wire type 32BIT/64BIT) // For integer types, use fixed kernel only if encoding is // spark_rapids_jni::protobuf::encoding_value(spark_rapids_jni::protobuf::proto_encoding::FIXED) constexpr bool is_floating_point = std::is_same_v || std::is_same_v; bool use_fixed_kernel = - is_floating_point || - (encoding == spark_rapids_jni::protobuf::encoding_value(spark_rapids_jni::protobuf::proto_encoding::FIXED)); + is_floating_point || (encoding == spark_rapids_jni::protobuf::encoding_value( + spark_rapids_jni::protobuf::proto_encoding::FIXED)); RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; if (use_fixed_kernel) { @@ -916,4 +916,3 @@ inline std::unique_ptr build_repeated_scalar_column( } } // namespace spark_rapids_jni::protobuf::detail - diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cuh b/src/main/cpp/src/protobuf/protobuf_kernels.cuh index 04aee0042c..622f2e22a9 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cuh +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cuh @@ -204,8 +204,8 @@ __global__ void extract_fixed_kernel(uint8_t const* message_data, uint8_t const* cur = message_data + data_offset; OutputType value; - 
if constexpr (WT == - spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { + if constexpr (WT == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { if (loc.length < 4) { set_error_once(error_flag, ERR_FIXED_LEN); if (valid) valid[idx] = false; @@ -330,8 +330,8 @@ __global__ void extract_fixed_batched_kernel(uint8_t const* message_data, uint8_t const* cur = message_data + data_offset; OutputType value; - if constexpr (WT == - spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { + if constexpr (WT == spark_rapids_jni::protobuf::wire_type_value( + spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { if (loc.length < 4) { set_error_once(error_flag, ERR_FIXED_LEN); desc.valid[row] = false; @@ -568,4 +568,3 @@ __global__ void copy_enum_string_chars_kernel(int32_t const* values, int num_rows); } // namespace spark_rapids_jni::protobuf::detail - diff --git a/src/main/cpp/src/protobuf/protobuf_types.cuh b/src/main/cpp/src/protobuf/protobuf_types.cuh index a6fb3114d3..c41adb1485 100644 --- a/src/main/cpp/src/protobuf/protobuf_types.cuh +++ b/src/main/cpp/src/protobuf/protobuf_types.cuh @@ -18,8 +18,6 @@ #include "protobuf/protobuf.hpp" -#include "protobuf/protobuf.hpp" - #include #include #include @@ -144,7 +142,8 @@ struct device_nested_field_descriptor { device_nested_field_descriptor() = default; - explicit device_nested_field_descriptor(spark_rapids_jni::protobuf::nested_field_descriptor const& src) + explicit device_nested_field_descriptor( + spark_rapids_jni::protobuf::nested_field_descriptor const& src) : field_number(src.field_number), parent_idx(src.parent_idx), depth(src.depth), @@ -159,4 +158,3 @@ struct device_nested_field_descriptor { }; } // namespace spark_rapids_jni::protobuf::detail - From ff83290a3a0f4353b8aaaed4d03e3e0468872c2b Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 19 Mar 2026 16:08:11 +0800 
Subject: [PATCH 095/107] address comments from part0 Signed-off-by: Haoyang Li --- src/main/cpp/benchmarks/protobuf_decode.cu | 6 +- src/main/cpp/src/ProtobufJni.cpp | 33 ++----- src/main/cpp/src/protobuf/protobuf.cu | 86 ++++++++--------- src/main/cpp/src/protobuf/protobuf.hpp | 4 +- .../cpp/src/protobuf/protobuf_builders.cu | 93 +++++++++---------- .../src/protobuf/protobuf_device_helpers.cuh | 8 +- .../src/protobuf/protobuf_host_helpers.hpp | 85 +++++++++++------ src/main/cpp/src/protobuf/protobuf_kernels.cu | 2 +- .../cpp/src/protobuf/protobuf_kernels.cuh | 11 ++- src/main/cpp/src/protobuf/protobuf_types.cuh | 37 +------- 10 files changed, 168 insertions(+), 197 deletions(-) diff --git a/src/main/cpp/benchmarks/protobuf_decode.cu b/src/main/cpp/benchmarks/protobuf_decode.cu index b54263a759..18b8b93199 100644 --- a/src/main/cpp/benchmarks/protobuf_decode.cu +++ b/src/main/cpp/benchmarks/protobuf_decode.cu @@ -1151,7 +1151,7 @@ static void BM_protobuf_repeated_child_string_count_scan(nvbench::state& state) work.reserve(num_repeated_children); for (int ri = 0; ri < num_repeated_children; ri++) { auto& w = *work.emplace_back(std::make_unique(num_rows, stream, mr)); - thrust::transform(rmm::exec_policy(stream), + thrust::transform(rmm::exec_policy_nosync(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), w.counts.data(), @@ -1159,7 +1159,7 @@ static void BM_protobuf_repeated_child_string_count_scan(nvbench::state& state) d_rep_info.data(), ri, num_repeated_children}); CUDF_CUDA_TRY(cudaMemsetAsync(w.offsets.data(), 0, sizeof(int32_t), stream.value())); thrust::inclusive_scan( - rmm::exec_policy(stream), w.counts.begin(), w.counts.end(), w.offsets.data() + 1); + rmm::exec_policy_nosync(stream), w.counts.begin(), w.counts.end(), w.offsets.data() + 1); CUDF_CUDA_TRY(cudaMemcpyAsync(&w.total_count, w.offsets.data() + num_rows, sizeof(int32_t), @@ -1274,7 +1274,7 @@ static void BM_protobuf_repeated_child_string_build(nvbench::state& 
state) auto& c = *children[i]; rmm::device_uvector list_offs(num_rows + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), c.counts.begin(), c.counts.end(), list_offs.begin(), 0); + rmm::exec_policy_nosync(stream), c.counts.begin(), c.counts.end(), list_offs.begin(), 0); CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, &c.total_count, sizeof(int32_t), diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index 4bfc9a9937..796ca3239a 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -15,12 +15,9 @@ */ #include "cudf_jni_apis.hpp" -#include "dtype_utils.hpp" #include "protobuf/protobuf.hpp" -#include #include -#include extern "C" { @@ -45,22 +42,12 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, jobjectArray enum_names, jboolean fail_on_errors) { - JNI_NULL_CHECK(env, binary_input_view, "binary_input_view is null", 0); - JNI_NULL_CHECK(env, field_numbers, "field_numbers is null", 0); - JNI_NULL_CHECK(env, parent_indices, "parent_indices is null", 0); - JNI_NULL_CHECK(env, depth_levels, "depth_levels is null", 0); - JNI_NULL_CHECK(env, wire_types, "wire_types is null", 0); - JNI_NULL_CHECK(env, output_type_ids, "output_type_ids is null", 0); - JNI_NULL_CHECK(env, encodings, "encodings is null", 0); - JNI_NULL_CHECK(env, is_repeated, "is_repeated is null", 0); - JNI_NULL_CHECK(env, is_required, "is_required is null", 0); - JNI_NULL_CHECK(env, has_default_value, "has_default_value is null", 0); - JNI_NULL_CHECK(env, default_ints, "default_ints is null", 0); - JNI_NULL_CHECK(env, default_floats, "default_floats is null", 0); - JNI_NULL_CHECK(env, default_bools, "default_bools is null", 0); - JNI_NULL_CHECK(env, default_strings, "default_strings is null", 0); - JNI_NULL_CHECK(env, enum_valid_values, "enum_valid_values is null", 0); - JNI_NULL_CHECK(env, enum_names, "enum_names is null", 0); + auto const all_inputs_valid = binary_input_view && field_numbers 
&& parent_indices && + depth_levels && wire_types && output_type_ids && encodings && + is_repeated && is_required && has_default_value && default_ints && + default_floats && default_bools && default_strings && + enum_valid_values && enum_names; + JNI_NULL_CHECK(env, all_inputs_valid, "one or more input arrays are null", 0); JNI_TRY { @@ -113,13 +100,6 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, n_has_default[i] != 0}); } - // Build output types - std::vector schema_output_types; - schema_output_types.reserve(num_fields); - for (int i = 0; i < num_fields; ++i) { - schema_output_types.emplace_back(static_cast(n_output_type_ids[i])); - } - // Convert boolean arrays std::vector default_bool_values; default_bool_values.reserve(num_fields); @@ -216,7 +196,6 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, } spark_rapids_jni::protobuf::ProtobufDecodeContext context{std::move(schema), - std::move(schema_output_types), std::move(default_int_values), std::move(default_float_values), std::move(default_bool_values), diff --git a/src/main/cpp/src/protobuf/protobuf.cu b/src/main/cpp/src/protobuf/protobuf.cu index e499e165fd..bdc6466a5c 100644 --- a/src/main/cpp/src/protobuf/protobuf.cu +++ b/src/main/cpp/src/protobuf/protobuf.cu @@ -173,9 +173,6 @@ bool is_encoding_compatible(nested_field_descriptor const& field, cudf::data_typ void validate_decode_context(ProtobufDecodeContext const& context) { auto const num_fields = context.schema.size(); - CUDF_EXPECTS(context.schema_output_types.size() == num_fields, - "protobuf decode context: schema_output_types size mismatch", - std::invalid_argument); CUDF_EXPECTS(context.default_ints.size() == num_fields, "protobuf decode context: default_ints size mismatch", std::invalid_argument); @@ -198,11 +195,7 @@ void validate_decode_context(ProtobufDecodeContext const& context) std::set> seen_field_numbers; for (size_t i = 0; i < num_fields; ++i) { auto const& field = context.schema[i]; - 
auto const& type = context.schema_output_types[i]; - CUDF_EXPECTS( - type.id() == field.output_type, - "protobuf decode context: schema_output_types id mismatch at field " + std::to_string(i), - std::invalid_argument); + auto const type = cudf::data_type{field.output_type}; CUDF_EXPECTS(field.field_number > 0 && field.field_number <= MAX_FIELD_NUMBER, "protobuf decode context: invalid field number at field " + std::to_string(i), std::invalid_argument); @@ -227,7 +220,7 @@ void validate_decode_context(ProtobufDecodeContext const& context) CUDF_EXPECTS(field.depth == parent.depth + 1, "protobuf decode context: child depth mismatch at field " + std::to_string(i), std::invalid_argument); - CUDF_EXPECTS(context.schema_output_types[field.parent_idx].id() == cudf::type_id::STRUCT, + CUDF_EXPECTS(context.schema[field.parent_idx].output_type == cudf::type_id::STRUCT, "protobuf decode context: parent must be STRUCT at field " + std::to_string(i), std::invalid_argument); } @@ -285,7 +278,7 @@ ProtobufFieldMetaView make_field_meta_view(ProtobufDecodeContext const& context, { auto const idx = static_cast(schema_idx); return ProtobufFieldMetaView{context.schema.at(idx), - context.schema_output_types.at(idx), + cudf::data_type{context.schema.at(idx).output_type}, context.default_ints.at(idx), context.default_floats.at(idx), context.default_bools.at(idx), @@ -300,15 +293,14 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& rmm::device_async_resource_ref mr) { validate_decode_context(context); - auto const& schema = context.schema; - auto const& schema_output_types = context.schema_output_types; - auto const& default_ints = context.default_ints; - auto const& default_floats = context.default_floats; - auto const& default_bools = context.default_bools; - auto const& default_strings = context.default_strings; - auto const& enum_valid_values = context.enum_valid_values; - auto const& enum_names = context.enum_names; - bool fail_on_errors = context.fail_on_errors; + 
auto const& schema = context.schema; + auto const& default_ints = context.default_ints; + auto const& default_floats = context.default_floats; + auto const& default_bools = context.default_bools; + auto const& default_strings = context.default_strings; + auto const& enum_valid_values = context.enum_valid_values; + auto const& enum_names = context.enum_names; + bool fail_on_errors = context.fail_on_errors; CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, "binary_input must be a LIST column"); cudf::lists_column_view const in_list(binary_input); @@ -324,21 +316,21 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& std::vector> empty_children; for (int i = 0; i < num_fields; i++) { if (schema[i].parent_idx == -1) { - auto field_type = schema_output_types[i]; + auto field_type = cudf::data_type{schema[i].output_type}; if (schema[i].is_repeated && field_type.id() == cudf::type_id::STRUCT) { // Repeated message field - build empty LIST with proper struct element rmm::device_uvector offsets(1, stream, mr); CUDF_CUDA_TRY(cudaMemsetAsync(offsets.data(), 0, sizeof(int32_t), stream.value())); auto offsets_col = std::make_unique( cudf::data_type{cudf::type_id::INT32}, 1, offsets.release(), rmm::device_buffer{}, 0); - auto empty_struct = make_empty_struct_column_with_schema( - schema, schema_output_types, i, num_fields, stream, mr); + auto empty_struct = + make_empty_struct_column_with_schema(schema, i, num_fields, stream, mr); empty_children.push_back(cudf::make_lists_column( 0, std::move(offsets_col), std::move(empty_struct), 0, rmm::device_buffer{})); } else if (field_type.id() == cudf::type_id::STRUCT && !schema[i].is_repeated) { // Non-repeated nested message field - empty_children.push_back(make_empty_struct_column_with_schema( - schema, schema_output_types, i, num_fields, stream, mr)); + empty_children.push_back( + make_empty_struct_column_with_schema(schema, i, num_fields, stream, mr)); } else { 
empty_children.push_back(make_empty_column_safe(field_type, stream, mr)); } @@ -581,7 +573,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& for (int i = 0; i < num_scalar; i++) { int si = scalar_field_indices[i]; - auto tid = schema_output_types[si].id(); + auto tid = cudf::data_type{schema[si].output_type}.id(); auto enc = schema[si].encoding; bool zz = (enc == proto_encoding::ZIGZAG); @@ -643,7 +635,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& auto& bp = *bufs.emplace_back(std::make_unique(stream, mr)); bp.valid = rmm::device_uvector(std::max(1, num_rows), stream, mr); // BOOL8 default comes from default_bools (converted to 0/1 int) - bool is_bool = (schema_output_types[si].id() == cudf::type_id::BOOL8); + bool is_bool = (cudf::data_type{schema[si].output_type}.id() == cudf::type_id::BOOL8); int64_t def_i = hd ? (is_bool ? (default_bools[si] ? 1 : 0) : default_ints[si]) : 0; h_descs[j] = {li, nullptr, bp.valid.data(), hd, def_i, hd ? default_floats[si] : 0.0}; } @@ -654,7 +646,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Build columns for (int j = 0; j < nf; j++) { int si = scalar_field_indices[idxs[j]]; - auto dt = schema_output_types[si]; + auto dt = cudf::data_type{schema[si].output_type}; auto& bp = *bufs[j]; auto [mask, null_count] = make_null_mask_from_valid(bp.valid, stream, mr); column_map[si] = std::make_unique( @@ -922,7 +914,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& auto& w = *rep_work.emplace_back( std::make_unique(schema_idx, num_rows, stream, mr)); - thrust::transform(rmm::exec_policy(stream), + thrust::transform(rmm::exec_policy_nosync(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), w.counts.data(), @@ -930,7 +922,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& CUDF_CUDA_TRY(cudaMemsetAsync(w.offsets.data(), 0, sizeof(int32_t), stream.value())); thrust::inclusive_scan( - 
rmm::exec_policy(stream), w.counts.begin(), w.counts.end(), w.offsets.data() + 1); + rmm::exec_policy_nosync(stream), w.counts.begin(), w.counts.end(), w.offsets.data() + 1); CUDF_CUDA_TRY(cudaMemcpyAsync(&w.total_count, w.offsets.data() + num_rows, @@ -997,7 +989,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& for (int ri = 0; ri < num_repeated; ri++) { auto& w = *rep_work[ri]; int schema_idx = w.schema_idx; - auto element_type = schema_output_types[schema_idx]; + auto element_type = cudf::data_type{schema[schema_idx].output_type}; int32_t total_count = w.total_count; auto& d_field_counts = w.counts; @@ -1005,7 +997,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& auto& d_occurrences = *w.occurrences; // Build the appropriate column type based on element type - auto child_type_id = static_cast(h_device_schema[schema_idx].output_type_id); + auto child_type_id = h_device_schema[schema_idx].output_type; // The output_type in schema is the LIST type, but we need element type // For repeated int32, output_type should indicate the element is INT32 @@ -1138,8 +1130,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& mr); } else { set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); - column_map[schema_idx] = - make_null_column(schema_output_types[schema_idx], num_rows, stream, mr); + column_map[schema_idx] = make_null_column( + cudf::data_type{schema[schema_idx].output_type}, num_rows, stream, mr); } } else { column_map[schema_idx] = build_repeated_string_column(binary_input, @@ -1177,8 +1169,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Repeated message field - ArrayType(StructType) auto child_field_indices = find_child_field_indices(schema, num_fields, schema_idx); if (child_field_indices.empty()) { - auto empty_struct_child = make_empty_struct_column_with_schema( - schema, schema_output_types, schema_idx, num_fields, stream, mr); + auto empty_struct_child = 
+ make_empty_struct_column_with_schema(schema, schema_idx, num_fields, stream, mr); column_map[schema_idx] = make_null_list_column_with_child( std::move(empty_struct_child), num_rows, stream, mr); } else { @@ -1194,7 +1186,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& num_rows, h_device_schema, child_field_indices, - schema_output_types, default_ints, default_floats, default_bools, @@ -1218,7 +1209,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } else { // All rows have count=0 - create list of empty elements rmm::device_uvector offsets(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), @@ -1227,13 +1218,14 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Build appropriate empty child column std::unique_ptr child_col; - auto child_type_id = static_cast(h_device_schema[schema_idx].output_type_id); + auto child_type_id = h_device_schema[schema_idx].output_type; if (child_type_id == cudf::type_id::STRUCT) { // Use helper to build empty struct with proper nested structure - child_col = make_empty_struct_column_with_schema( - schema, schema_output_types, schema_idx, num_fields, stream, mr); + child_col = + make_empty_struct_column_with_schema(schema, schema_idx, num_fields, stream, mr); } else { - child_col = make_empty_column_safe(schema_output_types[schema_idx], stream, mr); + child_col = + make_empty_column_safe(cudf::data_type{schema[schema_idx].output_type}, stream, mr); } auto const input_null_count = binary_input.null_count(); @@ -1262,8 +1254,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& if (child_field_indices.empty()) { // No child fields - create empty struct - column_map[parent_schema_idx] = - 
make_null_column(schema_output_types[parent_schema_idx], num_rows, stream, mr); + column_map[parent_schema_idx] = make_null_column( + cudf::data_type{schema[parent_schema_idx].output_type}, num_rows, stream, mr); continue; } @@ -1280,7 +1272,6 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& child_field_indices, schema, num_fields, - schema_output_types, default_ints, default_floats, default_bools, @@ -1306,11 +1297,10 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& top_level_children.push_back(std::move(column_map[i])); } else { if (schema[i].is_repeated) { - auto const element_type = schema_output_types[i]; + auto const element_type = cudf::data_type{schema[i].output_type}; std::unique_ptr empty_child; if (element_type.id() == cudf::type_id::STRUCT) { - empty_child = make_empty_struct_column_with_schema( - schema, schema_output_types, i, num_fields, stream, mr); + empty_child = make_empty_struct_column_with_schema(schema, i, num_fields, stream, mr); } else { empty_child = make_empty_column_safe(element_type, stream, mr); } @@ -1318,7 +1308,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& make_null_list_column_with_child(std::move(empty_child), num_rows, stream, mr)); } else { top_level_children.push_back( - make_null_column(schema_output_types[i], num_rows, stream, mr)); + make_null_column(cudf::data_type{schema[i].output_type}, num_rows, stream, mr)); } } } diff --git a/src/main/cpp/src/protobuf/protobuf.hpp b/src/main/cpp/src/protobuf/protobuf.hpp index 452220d399..144fd35817 100644 --- a/src/main/cpp/src/protobuf/protobuf.hpp +++ b/src/main/cpp/src/protobuf/protobuf.hpp @@ -19,7 +19,6 @@ #include #include #include -#include #include #include @@ -74,7 +73,6 @@ struct nested_field_descriptor { struct ProtobufDecodeContext { std::vector schema; - std::vector schema_output_types; std::vector default_ints; std::vector default_floats; std::vector default_bools; @@ -86,7 +84,7 @@ struct 
ProtobufDecodeContext { struct ProtobufFieldMetaView { nested_field_descriptor const& schema; - cudf::data_type const& output_type; + cudf::data_type output_type; int64_t default_int; double default_float; bool default_bool; diff --git a/src/main/cpp/src/protobuf/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu index 756948a206..19dda8a0ec 100644 --- a/src/main/cpp/src/protobuf/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -17,6 +17,7 @@ #include "protobuf/protobuf_host_helpers.hpp" #include +#include namespace spark_rapids_jni::protobuf::detail { @@ -50,7 +51,7 @@ inline std::unique_ptr build_repeated_msg_child_varlen_column( rmm::device_uvector d_lengths(total_count, stream, mr); thrust::transform( - rmm::exec_policy(stream), + rmm::exec_policy_nosync(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(total_count), d_lengths.data(), @@ -67,7 +68,7 @@ inline std::unique_ptr build_repeated_msg_child_varlen_column( rmm::device_uvector d_valid((total_count > 0 ? 
total_count : 1), stream, mr); thrust::transform( - rmm::exec_policy(stream), + rmm::exec_policy_nosync(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(total_count), d_valid.data(), @@ -143,7 +144,7 @@ std::unique_ptr make_null_column(cudf::data_type dtype, return cudf::make_fixed_width_column(dtype, num_rows, cudf::mask_state::ALL_NULL, stream, mr); case cudf::type_id::STRING: { rmm::device_uvector pairs(num_rows, stream, mr); - thrust::fill(rmm::exec_policy(stream), + thrust::fill(rmm::exec_policy_nosync(stream), pairs.data(), pairs.end(), cudf::strings::detail::string_index_pair{nullptr, 0}); @@ -210,7 +211,7 @@ std::unique_ptr make_null_list_column_with_child( rmm::device_async_resource_ref mr) { rmm::device_uvector offsets(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), @@ -354,7 +355,7 @@ std::unique_ptr build_enum_string_column( auto const blocks = static_cast((num_rows + threads - 1u) / threads); auto lookup = make_enum_string_lookup_tables(valid_enums, enum_name_bytes, stream, mr); rmm::device_uvector d_item_has_invalid_enum(num_rows, stream, mr); - thrust::fill(rmm::exec_policy(stream), + thrust::fill(rmm::exec_policy_nosync(stream), d_item_has_invalid_enum.begin(), d_item_has_invalid_enum.end(), false); @@ -416,7 +417,7 @@ inline std::unique_ptr build_repeated_msg_child_enum_string_column 0); rmm::device_uvector d_elem_has_invalid_enum(total_count, stream, mr); - thrust::fill(rmm::exec_policy(stream), + thrust::fill(rmm::exec_policy_nosync(stream), d_elem_has_invalid_enum.begin(), d_elem_has_invalid_enum.end(), false); @@ -475,7 +476,7 @@ std::unique_ptr build_repeated_enum_string_column( // (elem_valid was already populated by extract_varint_kernel: true for success, false for // 
failure) rmm::device_uvector d_elem_has_invalid_enum(total_count, stream, mr); - thrust::fill(rmm::exec_policy(stream), + thrust::fill(rmm::exec_policy_nosync(stream), d_elem_has_invalid_enum.begin(), d_elem_has_invalid_enum.end(), false); @@ -488,7 +489,7 @@ std::unique_ptr build_repeated_enum_string_column( total_count); rmm::device_uvector d_top_row_indices(total_count, stream, mr); - thrust::transform(rmm::exec_policy(stream), + thrust::transform(rmm::exec_policy_nosync(stream), d_occurrences.begin(), d_occurrences.end(), d_top_row_indices.begin(), @@ -507,9 +508,9 @@ std::unique_ptr build_repeated_enum_string_column( // Build the final LIST column from the per-row counts and decoded child strings. rmm::device_uvector lo(num_rows + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), lo.begin(), 0); + rmm::exec_policy_nosync(stream), d_field_counts.begin(), d_field_counts.end(), lo.begin(), 0); int32_t tc_i32 = static_cast(total_count); - thrust::fill_n(rmm::exec_policy(stream), lo.data() + num_rows, 1, tc_i32); + thrust::fill_n(rmm::exec_policy_nosync(stream), lo.data() + num_rows, 1, tc_i32); auto list_offs_col = std::make_unique( cudf::data_type{cudf::type_id::INT32}, num_rows + 1, lo.release(), rmm::device_buffer{}, 0); @@ -547,7 +548,7 @@ std::unique_ptr build_repeated_string_column( if (total_count == 0) { // All rows have count=0, but we still need to check input nulls rmm::device_uvector offsets(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), @@ -573,11 +574,14 @@ std::unique_ptr build_repeated_string_column( } rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), d_field_counts.begin(), 
d_field_counts.end(), list_offs.begin(), 0); + thrust::exclusive_scan(rmm::exec_policy_nosync(stream), + d_field_counts.begin(), + d_field_counts.end(), + list_offs.begin(), + 0); int32_t total_count_i32 = static_cast(total_count); - thrust::fill_n(rmm::exec_policy(stream), list_offs.data() + num_rows, 1, total_count_i32); + thrust::fill_n(rmm::exec_policy_nosync(stream), list_offs.data() + num_rows, 1, total_count_i32); // Extract string lengths from occurrences rmm::device_uvector str_lengths(total_count, stream, mr); @@ -650,7 +654,6 @@ std::unique_ptr build_nested_struct_column( std::vector const& child_field_indices, std::vector const& schema, int num_fields, - std::vector const& schema_output_types, std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, @@ -679,7 +682,6 @@ std::unique_ptr build_repeated_child_list_column( int child_schema_idx, std::vector const& schema, int num_fields, - std::vector const& schema_output_types, std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, @@ -707,7 +709,6 @@ std::unique_ptr build_repeated_struct_column( int num_rows, std::vector const& h_device_schema, std::vector const& child_field_indices, - std::vector const& schema_output_types, std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, @@ -726,7 +727,7 @@ std::unique_ptr build_repeated_struct_column( if (total_count == 0 || num_child_fields == 0) { // All rows have count=0 or no child fields - return list of empty structs rmm::device_uvector offsets(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), @@ -737,11 +738,11 @@ std::unique_ptr build_repeated_struct_column( int num_schema_fields 
= static_cast(h_device_schema.size()); std::vector> empty_struct_children; for (int child_schema_idx : child_field_indices) { - auto child_type = schema_output_types[child_schema_idx]; + auto child_type = cudf::data_type{schema[child_schema_idx].output_type}; std::unique_ptr child_col; if (child_type.id() == cudf::type_id::STRUCT) { child_col = make_empty_struct_column_with_schema( - h_device_schema, schema_output_types, child_schema_idx, num_schema_fields, stream, mr); + h_device_schema, child_schema_idx, num_schema_fields, stream, mr); } else { child_col = make_empty_column_safe(child_type, stream, mr); } @@ -767,11 +768,14 @@ std::unique_ptr build_repeated_struct_column( } rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); + thrust::exclusive_scan(rmm::exec_policy_nosync(stream), + d_field_counts.begin(), + d_field_counts.end(), + list_offs.begin(), + 0); int32_t total_count_i32 = static_cast(total_count); - thrust::fill_n(rmm::exec_policy(stream), list_offs.data() + num_rows, 1, total_count_i32); + thrust::fill_n(rmm::exec_policy_nosync(stream), list_offs.data() + num_rows, 1, total_count_i32); // Build child field descriptors for scanning within each message occurrence std::vector h_child_descs(num_child_fields); @@ -816,7 +820,7 @@ std::unique_ptr build_repeated_struct_column( d_error_top.data()); } rmm::device_uvector d_top_row_indices(total_count, stream, mr); - thrust::transform(rmm::exec_policy(stream), + thrust::transform(rmm::exec_policy_nosync(stream), d_occurrences.data(), d_occurrences.end(), d_top_row_indices.data(), @@ -869,7 +873,7 @@ std::unique_ptr build_repeated_struct_column( int num_schema_fields = static_cast(h_device_schema.size()); for (int ci = 0; ci < num_child_fields; ci++) { int child_schema_idx = child_field_indices[ci]; - auto const dt = schema_output_types[child_schema_idx]; + auto const dt = 
cudf::data_type{schema[child_schema_idx].output_type}; auto const enc = h_device_schema[child_schema_idx].encoding; bool has_def = h_device_schema[child_schema_idx].has_default_value; bool child_is_repeated = h_device_schema[child_schema_idx].is_repeated; @@ -884,7 +888,6 @@ std::unique_ptr build_repeated_struct_column( child_schema_idx, schema, num_schema_fields, - schema_output_types, default_ints, default_floats, default_bools, @@ -1032,7 +1035,6 @@ std::unique_ptr build_repeated_struct_column( grandchild_indices, schema, num_schema_fields, - schema_output_types, default_ints, default_floats, default_bools, @@ -1091,7 +1093,6 @@ std::unique_ptr build_nested_struct_column( std::vector const& child_field_indices, std::vector const& schema, int num_fields, - std::vector const& schema_output_types, std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, @@ -1113,11 +1114,11 @@ std::unique_ptr build_nested_struct_column( if (num_rows == 0) { std::vector> empty_children; for (int child_schema_idx : child_field_indices) { - auto child_type = schema_output_types[child_schema_idx]; + auto child_type = cudf::data_type{schema[child_schema_idx].output_type}; std::unique_ptr child_col; if (child_type.id() == cudf::type_id::STRUCT) { - child_col = make_empty_struct_column_with_schema( - schema, schema_output_types, child_schema_idx, num_fields, stream, mr); + child_col = + make_empty_struct_column_with_schema(schema, child_schema_idx, num_fields, stream, mr); } else { child_col = make_empty_column_safe(child_type, stream, mr); } @@ -1180,7 +1181,7 @@ std::unique_ptr build_nested_struct_column( std::vector> struct_children; for (int ci = 0; ci < num_child_fields; ci++) { int child_schema_idx = child_field_indices[ci]; - auto const dt = schema_output_types[child_schema_idx]; + auto const dt = cudf::data_type{schema[child_schema_idx].output_type}; auto const enc = static_cast(schema[child_schema_idx].encoding); bool has_def = 
schema[child_schema_idx].has_default_value; bool is_repeated = schema[child_schema_idx].is_repeated; @@ -1195,7 +1196,6 @@ std::unique_ptr build_nested_struct_column( child_schema_idx, schema, num_fields, - schema_output_types, default_ints, default_floats, default_bools, @@ -1399,7 +1399,6 @@ std::unique_ptr build_nested_struct_column( gc_indices, schema, num_fields, - schema_output_types, default_ints, default_floats, default_bools, @@ -1422,7 +1421,7 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector struct_valid((num_rows > 0 ? num_rows : 1), stream, mr); thrust::transform( - rmm::exec_policy(stream), + rmm::exec_policy_nosync(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), struct_valid.data(), @@ -1446,7 +1445,6 @@ std::unique_ptr build_repeated_child_list_column( int child_schema_idx, std::vector const& schema, int num_fields, - std::vector const& schema_output_types, std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, @@ -1475,7 +1473,7 @@ std::unique_ptr build_repeated_child_list_column( device_nested_field_descriptor rep_desc; rep_desc.field_number = schema[child_schema_idx].field_number; rep_desc.wire_type = static_cast(schema[child_schema_idx].wire_type); - rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); + rep_desc.output_type = schema[child_schema_idx].output_type; rep_desc.is_repeated = true; rep_desc.parent_idx = -1; rep_desc.depth = 0; @@ -1509,17 +1507,18 @@ std::unique_ptr build_repeated_child_list_column( d_error.data()); rmm::device_uvector d_rep_counts(num_parent_rows, stream, mr); - thrust::transform(rmm::exec_policy(stream), + thrust::transform(rmm::exec_policy_nosync(stream), d_rep_info.data(), d_rep_info.end(), d_rep_counts.data(), [] __device__(repeated_field_info const& info) { return info.count; }); int total_rep_count = - thrust::reduce(rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), 0); 
+ thrust::reduce(rmm::exec_policy_nosync(stream), d_rep_counts.data(), d_rep_counts.end(), 0); if (total_rep_count == 0) { rmm::device_uvector list_offsets_vec(num_parent_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), list_offsets_vec.data(), list_offsets_vec.end(), 0); + thrust::fill( + rmm::exec_policy_nosync(stream), list_offsets_vec.data(), list_offsets_vec.end(), 0); auto list_offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, num_parent_rows + 1, list_offsets_vec.release(), @@ -1527,8 +1526,8 @@ std::unique_ptr build_repeated_child_list_column( 0); std::unique_ptr child_col; if (elem_type_id == cudf::type_id::STRUCT) { - child_col = make_empty_struct_column_with_schema( - schema, schema_output_types, child_schema_idx, num_fields, stream, mr); + child_col = + make_empty_struct_column_with_schema(schema, child_schema_idx, num_fields, stream, mr); } else { child_col = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); } @@ -1538,8 +1537,9 @@ std::unique_ptr build_repeated_child_list_column( rmm::device_uvector list_offs(num_parent_rows + 1, stream, mr); thrust::exclusive_scan( - rmm::exec_policy(stream), d_rep_counts.data(), d_rep_counts.end(), list_offs.begin(), 0); - thrust::fill_n(rmm::exec_policy(stream), list_offs.data() + num_parent_rows, 1, total_rep_count); + rmm::exec_policy_nosync(stream), d_rep_counts.data(), d_rep_counts.end(), list_offs.begin(), 0); + thrust::fill_n( + rmm::exec_policy_nosync(stream), list_offs.data() + num_parent_rows, 1, total_rep_count); rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); scan_repeated_in_nested_kernel<<>>(message_data, @@ -1555,7 +1555,7 @@ std::unique_ptr build_repeated_child_list_column( d_error.data()); rmm::device_uvector d_rep_top_row_indices(total_rep_count, stream, mr); - thrust::transform(rmm::exec_policy(stream), + thrust::transform(rmm::exec_policy_nosync(stream), d_rep_occs.begin(), d_rep_occs.end(), d_rep_top_row_indices.begin(), @@ -1617,7 
+1617,7 @@ std::unique_ptr build_repeated_child_list_column( 0); rmm::device_uvector d_elem_has_invalid_enum(total_rep_count, stream, mr); - thrust::fill(rmm::exec_policy(stream), + thrust::fill(rmm::exec_policy_nosync(stream), d_elem_has_invalid_enum.begin(), d_elem_has_invalid_enum.end(), false); @@ -1690,7 +1690,6 @@ std::unique_ptr build_repeated_child_list_column( gc_indices, schema, num_fields, - schema_output_types, default_ints, default_floats, default_bools, diff --git a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh index 09fc15889c..053e06a34b 100644 --- a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh +++ b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh @@ -18,7 +18,12 @@ #include "protobuf/protobuf_types.cuh" +#include + +#include + #include +#include namespace spark_rapids_jni::protobuf::detail { @@ -217,7 +222,8 @@ __device__ __host__ inline size_t flat_index(size_t row, size_t width, size_t co __device__ inline bool checked_add_int32(int32_t lhs, int32_t rhs, int32_t& out) { auto const sum = static_cast(lhs) + rhs; - if (sum < std::numeric_limits::min() || sum > std::numeric_limits::max()) { + if (sum < cuda::std::numeric_limits::min() || + sum > cuda::std::numeric_limits::max()) { return false; } out = static_cast(sum); diff --git a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp index 2f426e0bf1..691defbfd3 100644 --- a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp +++ b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp @@ -18,6 +18,36 @@ #include "protobuf/protobuf_kernels.cuh" +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + namespace spark_rapids_jni::protobuf::detail { // 
============================================================================ @@ -245,7 +275,6 @@ std::unique_ptr make_empty_list_column(std::unique_ptr std::unique_ptr make_empty_struct_column_with_schema( SchemaT const& schema, - std::vector const& schema_output_types, int parent_idx, int num_fields, rmm::cuda_stream_view stream, @@ -255,12 +284,11 @@ std::unique_ptr make_empty_struct_column_with_schema( std::vector> children; for (int child_idx : child_indices) { - auto child_type = schema_output_types[child_idx]; + auto child_type = cudf::data_type{schema[child_idx].output_type}; std::unique_ptr child_col; if (child_type.id() == cudf::type_id::STRUCT) { - child_col = make_empty_struct_column_with_schema( - schema, schema_output_types, child_idx, num_fields, stream, mr); + child_col = make_empty_struct_column_with_schema(schema, child_idx, num_fields, stream, mr); } else { child_col = make_empty_column_safe(child_type, stream, mr); } @@ -357,7 +385,7 @@ inline void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const if (top_row_indices == nullptr) { CUDF_EXPECTS(static_cast(num_items) <= row_invalid.size(), "enum invalid-row propagation exceeded row buffer"); - thrust::transform(rmm::exec_policy(stream), + thrust::transform(rmm::exec_policy_nosync(stream), row_invalid.begin(), row_invalid.begin() + num_items, item_invalid.begin(), @@ -369,7 +397,7 @@ inline void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const } rmm::device_uvector invalid_rows(num_items, stream, mr); - thrust::transform(rmm::exec_policy(stream), + thrust::transform(rmm::exec_policy_nosync(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_items), invalid_rows.begin(), @@ -378,10 +406,11 @@ inline void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const }); auto valid_end = - thrust::remove(rmm::exec_policy(stream), invalid_rows.begin(), invalid_rows.end(), -1); - thrust::sort(rmm::exec_policy(stream), invalid_rows.begin(), 
valid_end); - auto unique_end = thrust::unique(rmm::exec_policy(stream), invalid_rows.begin(), valid_end); - thrust::for_each(rmm::exec_policy(stream), + thrust::remove(rmm::exec_policy_nosync(stream), invalid_rows.begin(), invalid_rows.end(), -1); + thrust::sort(rmm::exec_policy_nosync(stream), invalid_rows.begin(), valid_end); + auto unique_end = + thrust::unique(rmm::exec_policy_nosync(stream), invalid_rows.begin(), valid_end); + thrust::for_each(rmm::exec_policy_nosync(stream), invalid_rows.begin(), unique_end, [row_invalid = row_invalid.data()] __device__(int32_t row_idx) { @@ -410,7 +439,7 @@ inline void validate_enum_and_propagate_rows(rmm::device_uvector const& stream.value())); rmm::device_uvector item_invalid(num_items, stream, mr); - thrust::fill(rmm::exec_policy(stream), item_invalid.begin(), item_invalid.end(), false); + thrust::fill(rmm::exec_policy_nosync(stream), item_invalid.begin(), item_invalid.end(), false); validate_enum_values_kernel<<>>( values.data(), valid.data(), @@ -491,7 +520,6 @@ std::unique_ptr build_nested_struct_column( std::vector const& child_field_indices, std::vector const& schema, int num_fields, - std::vector const& schema_output_types, std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, @@ -517,7 +545,6 @@ std::unique_ptr build_repeated_child_list_column( int child_schema_idx, std::vector const& schema, int num_fields, - std::vector const& schema_output_types, std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, @@ -545,7 +572,6 @@ std::unique_ptr build_repeated_struct_column( int num_rows, std::vector const& h_device_schema, std::vector const& child_field_indices, - std::vector const& schema_output_types, std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, @@ -603,7 +629,7 @@ inline std::unique_ptr extract_and_build_string_or_bytes_column( } rmm::device_uvector valid((num_rows > 
0 ? num_rows : 1), stream, mr); - thrust::transform(rmm::exec_policy(stream), + thrust::transform(rmm::exec_policy_nosync(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_rows), valid.data(), @@ -812,16 +838,15 @@ inline std::unique_ptr build_repeated_scalar_column( if (total_count == 0) { // All rows have count=0, but we still need to check input nulls rmm::device_uvector offsets(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy(stream), offsets.begin(), offsets.end(), 0); + thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, num_rows + 1, offsets.release(), rmm::device_buffer{}, 0); - auto elem_type = field_desc.output_type_id == static_cast(cudf::type_id::LIST) - ? cudf::type_id::UINT8 - : static_cast(field_desc.output_type_id); - auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); + auto elem_type = + field_desc.output_type == cudf::type_id::LIST ? 
cudf::type_id::UINT8 : field_desc.output_type; + auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); if (input_null_count > 0) { // Copy input null mask - only input nulls produce output nulls @@ -839,11 +864,14 @@ inline std::unique_ptr build_repeated_scalar_column( } rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan( - rmm::exec_policy(stream), d_field_counts.begin(), d_field_counts.end(), list_offs.begin(), 0); + thrust::exclusive_scan(rmm::exec_policy_nosync(stream), + d_field_counts.begin(), + d_field_counts.end(), + list_offs.begin(), + 0); int32_t total_count_i32 = static_cast(total_count); - thrust::fill_n(rmm::exec_policy(stream), list_offs.data() + num_rows, 1, total_count_i32); + thrust::fill_n(rmm::exec_policy_nosync(stream), list_offs.data() + num_rows, 1, total_count_i32); rmm::device_uvector values(total_count, stream, mr); @@ -892,12 +920,11 @@ inline std::unique_ptr build_repeated_scalar_column( list_offs.release(), rmm::device_buffer{}, 0); - auto child_col = std::make_unique( - cudf::data_type{static_cast(field_desc.output_type_id)}, - total_count, - values.release(), - rmm::device_buffer{}, - 0); + auto child_col = std::make_unique(cudf::data_type{field_desc.output_type}, + total_count, + values.release(), + rmm::device_buffer{}, + 0); // Only rows where INPUT is null should produce null output // Rows with valid input but count=0 should produce empty array [] diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cu b/src/main/cpp/src/protobuf/protobuf_kernels.cu index ea7d42b6d9..bf34ec0c67 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "protobuf/protobuf_device_helpers.cuh" +#include "protobuf/protobuf_kernels.cuh" namespace spark_rapids_jni::protobuf::detail { diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cuh b/src/main/cpp/src/protobuf/protobuf_kernels.cuh index 622f2e22a9..1aaca16ca3 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cuh +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cuh @@ -18,6 +18,11 @@ #include "protobuf/protobuf_device_helpers.cuh" +#include +#include + +#include + namespace spark_rapids_jni::protobuf::detail { // ============================================================================ @@ -142,7 +147,7 @@ __global__ void extract_varint_kernel(uint8_t const* message_data, // For BOOL8 (uint8_t), protobuf spec says any non-zero varint is true. // A raw static_cast would silently truncate values >= 256 to 0. auto const write_value = [](OutputType* dst, uint64_t val) { - if constexpr (std::is_same_v) { + if constexpr (cuda::std::is_same_v) { *dst = static_cast(val != 0 ? 1 : 0); } else { *dst = static_cast(val); @@ -260,7 +265,7 @@ __global__ void extract_varint_batched_kernel(uint8_t const* message_data, auto* out = static_cast(desc.output); auto const write_value = [](OutputType* dst, uint64_t val) { - if constexpr (std::is_same_v) { + if constexpr (cuda::std::is_same_v) { *dst = static_cast(val != 0 ? 
1 : 0); } else { *dst = static_cast(val); @@ -314,7 +319,7 @@ __global__ void extract_fixed_batched_kernel(uint8_t const* message_data, if (loc.offset < 0) { if (desc.has_default) { - if constexpr (std::is_integral_v) { + if constexpr (cuda::std::is_integral_v) { out[row] = static_cast(desc.default_int); } else { out[row] = static_cast(desc.default_float); diff --git a/src/main/cpp/src/protobuf/protobuf_types.cuh b/src/main/cpp/src/protobuf/protobuf_types.cuh index c41adb1485..35d51966f1 100644 --- a/src/main/cpp/src/protobuf/protobuf_types.cuh +++ b/src/main/cpp/src/protobuf/protobuf_types.cuh @@ -18,39 +18,6 @@ #include "protobuf/protobuf.hpp" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - namespace spark_rapids_jni::protobuf::detail { // Protobuf varint encoding uses at most 10 bytes to represent a 64-bit value. 
@@ -134,7 +101,7 @@ struct device_nested_field_descriptor { int parent_idx; int depth; int wire_type; - int output_type_id; + cudf::type_id output_type; int encoding; bool is_repeated; bool is_required; @@ -148,7 +115,7 @@ struct device_nested_field_descriptor { parent_idx(src.parent_idx), depth(src.depth), wire_type(static_cast(src.wire_type)), - output_type_id(static_cast(src.output_type)), + output_type(src.output_type), encoding(static_cast(src.encoding)), is_repeated(src.is_repeated), is_required(src.is_required), From 35b1cb8638a4288191b22ce3db8298ce6aac9fca Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Thu, 19 Mar 2026 16:39:26 +0800 Subject: [PATCH 096/107] nghia style self-check Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf/protobuf.cu | 24 ++---- src/main/cpp/src/protobuf/protobuf.hpp | 1 + .../cpp/src/protobuf/protobuf_builders.cu | 17 ++--- .../src/protobuf/protobuf_device_helpers.cuh | 53 ++++--------- .../src/protobuf/protobuf_host_helpers.hpp | 76 ++++--------------- src/main/cpp/src/protobuf/protobuf_kernels.cu | 75 +++++++----------- .../cpp/src/protobuf/protobuf_kernels.cuh | 6 +- src/main/cpp/src/protobuf/protobuf_types.cuh | 6 +- 8 files changed, 80 insertions(+), 178 deletions(-) diff --git a/src/main/cpp/src/protobuf/protobuf.cu b/src/main/cpp/src/protobuf/protobuf.cu index bdc6466a5c..63a655eed6 100644 --- a/src/main/cpp/src/protobuf/protobuf.cu +++ b/src/main/cpp/src/protobuf/protobuf.cu @@ -20,10 +20,10 @@ #include #include -using namespace spark_rapids_jni::protobuf::detail; - namespace spark_rapids_jni::protobuf { +using namespace detail; + namespace { void propagate_nulls_to_descendants(cudf::column& col, @@ -707,22 +707,10 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& LAUNCH_VARINT_BATCH(4, uint8_t, false); LAUNCH_VARINT_BATCH(5, int32_t, true); LAUNCH_VARINT_BATCH(6, int64_t, true); - LAUNCH_FIXED_BATCH(7, - float, - spark_rapids_jni::protobuf::wire_type_value( - 
spark_rapids_jni::protobuf::proto_wire_type::I32BIT)); - LAUNCH_FIXED_BATCH(8, - double, - spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I64BIT)); - LAUNCH_FIXED_BATCH(9, - int32_t, - spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I32BIT)); - LAUNCH_FIXED_BATCH(10, - int64_t, - spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I64BIT)); + LAUNCH_FIXED_BATCH(7, float, wire_type_value(proto_wire_type::I32BIT)); + LAUNCH_FIXED_BATCH(8, double, wire_type_value(proto_wire_type::I64BIT)); + LAUNCH_FIXED_BATCH(9, int32_t, wire_type_value(proto_wire_type::I32BIT)); + LAUNCH_FIXED_BATCH(10, int64_t, wire_type_value(proto_wire_type::I64BIT)); #undef LAUNCH_VARINT_BATCH #undef LAUNCH_FIXED_BATCH diff --git a/src/main/cpp/src/protobuf/protobuf.hpp b/src/main/cpp/src/protobuf/protobuf.hpp index 144fd35817..b86a5960d3 100644 --- a/src/main/cpp/src/protobuf/protobuf.hpp +++ b/src/main/cpp/src/protobuf/protobuf.hpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include diff --git a/src/main/cpp/src/protobuf/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu index 19dda8a0ec..7888696adc 100644 --- a/src/main/cpp/src/protobuf/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -26,7 +26,7 @@ namespace spark_rapids_jni::protobuf::detail { * When as_bytes=false, builds a STRING column. When as_bytes=true, builds LIST. * Uses GPU kernels for parallel extraction (critical performance fix!). 
*/ -inline std::unique_ptr build_repeated_msg_child_varlen_column( +std::unique_ptr build_repeated_msg_child_varlen_column( uint8_t const* message_data, rmm::device_uvector const& d_msg_row_offsets, rmm::device_uvector const& d_msg_locs, @@ -246,7 +246,7 @@ struct enum_string_lookup_tables { rmm::device_uvector d_name_chars; }; -inline enum_string_lookup_tables make_enum_string_lookup_tables( +enum_string_lookup_tables make_enum_string_lookup_tables( std::vector const& valid_enums, std::vector> const& enum_name_bytes, rmm::cuda_stream_view stream, @@ -296,7 +296,7 @@ inline enum_string_lookup_tables make_enum_string_lookup_tables( return {std::move(d_valid_enums), std::move(d_name_offsets), std::move(d_name_chars)}; } -inline std::unique_ptr build_enum_string_values_column( +std::unique_ptr build_enum_string_values_column( rmm::device_uvector& enum_values, rmm::device_uvector& valid, enum_string_lookup_tables const& lookup, @@ -377,7 +377,7 @@ std::unique_ptr build_enum_string_column( return build_enum_string_values_column(enum_values, valid, lookup, num_rows, stream, mr); } -inline std::unique_ptr build_repeated_msg_child_enum_string_column( +std::unique_ptr build_repeated_msg_child_enum_string_column( uint8_t const* message_data, rmm::device_uvector const& d_msg_row_offsets, rmm::device_uvector const& d_msg_locs, @@ -943,8 +943,7 @@ std::unique_ptr build_repeated_struct_column( break; } case cudf::type_id::STRING: { - if (enc == spark_rapids_jni::protobuf::encoding_value( - spark_rapids_jni::protobuf::proto_encoding::ENUM_STRING)) { + if (enc == encoding_value(proto_encoding::ENUM_STRING)) { if (child_schema_idx < static_cast(enum_valid_values.size()) && child_schema_idx < static_cast(enum_names.size()) && !enum_valid_values[child_schema_idx].empty() && @@ -1251,8 +1250,7 @@ std::unique_ptr build_nested_struct_column( break; } case cudf::type_id::STRING: { - if (enc == spark_rapids_jni::protobuf::encoding_value( - 
spark_rapids_jni::protobuf::proto_encoding::ENUM_STRING)) { + if (enc == encoding_value(proto_encoding::ENUM_STRING)) { rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; @@ -1596,8 +1594,7 @@ std::unique_ptr build_repeated_child_list_column( propagate_invalid_rows); } else if (elem_type_id == cudf::type_id::STRING || elem_type_id == cudf::type_id::LIST) { if (elem_type_id == cudf::type_id::STRING && - schema[child_schema_idx].encoding == - spark_rapids_jni::protobuf::proto_encoding::ENUM_STRING) { + schema[child_schema_idx].encoding == proto_encoding::ENUM_STRING) { if (child_schema_idx < static_cast(enum_valid_values.size()) && child_schema_idx < static_cast(enum_names.size()) && !enum_valid_values[child_schema_idx].empty() && diff --git a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh index 053e06a34b..fe020bd825 100644 --- a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh +++ b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh @@ -73,8 +73,7 @@ inline void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stre __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) { switch (wt) { - case spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::VARINT): { + case wire_type_value(proto_wire_type::VARINT): { // Need to scan to find the end of varint int count = 0; while (cur < end && count < MAX_VARINT_BYTES) { @@ -83,18 +82,15 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con } return -1; // Invalid varint } - case spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I64BIT): + case wire_type_value(proto_wire_type::I64BIT): // Check if there's enough data for 8 bytes if (end - cur < 8) return -1; return 8; - case 
spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I32BIT): + case wire_type_value(proto_wire_type::I32BIT): // Check if there's enough data for 4 bytes if (end - cur < 4) return -1; return 4; - case spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::LEN): { + case wire_type_value(proto_wire_type::LEN): { uint64_t len; int n; if (!read_varint(cur, end, len, n)) return -1; @@ -102,8 +98,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con return -1; return n + static_cast(len); } - case spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::SGROUP): { + case wire_type_value(proto_wire_type::SGROUP): { auto const* start = cur; int depth = 1; while (cur < end && depth > 0) { @@ -113,30 +108,23 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con cur += key_bytes; int inner_wt = static_cast(key & 0x7); - if (inner_wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::EGROUP)) { + if (inner_wt == wire_type_value(proto_wire_type::EGROUP)) { --depth; if (depth == 0) { return static_cast(cur - start); } - } else if (inner_wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::SGROUP)) { + } else if (inner_wt == wire_type_value(proto_wire_type::SGROUP)) { if (++depth > 32) return -1; } else { int inner_size = -1; switch (inner_wt) { - case spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::VARINT): { + case wire_type_value(proto_wire_type::VARINT): { uint64_t dummy; int vbytes; if (!read_varint(cur, end, dummy, vbytes)) return -1; inner_size = vbytes; break; } - case spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I64BIT): - inner_size = 8; - break; - case spark_rapids_jni::protobuf::wire_type_value( - 
spark_rapids_jni::protobuf::proto_wire_type::LEN): { + case wire_type_value(proto_wire_type::I64BIT): inner_size = 8; break; + case wire_type_value(proto_wire_type::LEN): { uint64_t len; int len_bytes; if (!read_varint(cur, end, len, len_bytes)) return -1; @@ -144,10 +132,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con inner_size = len_bytes + static_cast(len); break; } - case spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I32BIT): - inner_size = 4; - break; + case wire_type_value(proto_wire_type::I32BIT): inner_size = 4; break; default: return -1; } if (inner_size < 0 || cur + inner_size > end) return -1; @@ -156,9 +141,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con } return -1; } - case spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::EGROUP): - return 0; + case wire_type_value(proto_wire_type::EGROUP): return 0; default: return -1; } } @@ -169,13 +152,10 @@ __device__ inline bool skip_field(uint8_t const* cur, uint8_t const*& out_cur) { // A bare end-group is only valid while a start-group payload is being parsed recursively inside - // get_wire_type_size(spark_rapids_jni::protobuf::wire_type_value(spark_rapids_jni::protobuf::proto_wire_type::SGROUP)). + // get_wire_type_size(wire_type_value(proto_wire_type::SGROUP)). // The scan/count kernels should never accept it as a standalone field because Spark CPU treats // unmatched end-groups as malformed protobuf. 
- if (wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::EGROUP)) { - return false; - } + if (wt == wire_type_value(proto_wire_type::EGROUP)) { return false; } int size = get_wire_type_size(wt, cur, end); if (size < 0) return false; @@ -192,8 +172,7 @@ __device__ inline bool skip_field(uint8_t const* cur, __device__ inline bool get_field_data_location( uint8_t const* cur, uint8_t const* end, int wt, int32_t& data_offset, int32_t& data_length) { - if (wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::LEN)) { + if (wt == wire_type_value(proto_wire_type::LEN)) { // For length-delimited, read the length prefix uint64_t len; int len_bytes; @@ -261,7 +240,7 @@ __device__ inline bool decode_tag(uint8_t const*& cur, cur += key_bytes; uint64_t fn = key >> 3; - if (fn == 0 || fn > static_cast(spark_rapids_jni::protobuf::MAX_FIELD_NUMBER)) { + if (fn == 0 || fn > static_cast(MAX_FIELD_NUMBER)) { set_error_once(error_flag, ERR_FIELD_NUMBER); return false; } diff --git a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp index 691defbfd3..bb426c5f3b 100644 --- a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp +++ b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp @@ -137,8 +137,7 @@ inline void extract_integer_into_buffers(uint8_t const* message_data, int* error_ptr, rmm::cuda_stream_view stream) { - if (enable_zigzag && encoding == spark_rapids_jni::protobuf::encoding_value( - spark_rapids_jni::protobuf::proto_encoding::ZIGZAG)) { + if (enable_zigzag && encoding == encoding_value(proto_encoding::ZIGZAG)) { extract_varint_kernel <<>>(message_data, loc_provider, @@ -148,13 +147,9 @@ inline void extract_integer_into_buffers(uint8_t const* message_data, error_ptr, has_default, default_value); - } else if (encoding == spark_rapids_jni::protobuf::encoding_value( - spark_rapids_jni::protobuf::proto_encoding::FIXED)) { + } else if 
(encoding == encoding_value(proto_encoding::FIXED)) { if constexpr (sizeof(T) == 4) { - extract_fixed_kernel + extract_fixed_kernel <<>>(message_data, loc_provider, num_rows, @@ -165,10 +160,7 @@ inline void extract_integer_into_buffers(uint8_t const* message_data, static_cast(default_value)); } else { static_assert(sizeof(T) == 8, "extract_integer_into_buffers only supports 32/64-bit"); - extract_fixed_kernel + extract_fixed_kernel <<>>(message_data, loc_provider, num_rows, @@ -347,31 +339,6 @@ inline void maybe_check_required_fields(field_location const* locations, error_flag); } -__global__ void validate_enum_values_kernel(int32_t const* values, - bool* valid, - bool* row_has_invalid_enum, - int32_t const* valid_enum_values, - int num_valid_values, - int num_rows); - -__global__ void compute_enum_string_lengths_kernel(int32_t const* values, - bool const* valid, - int32_t const* valid_enum_values, - int32_t const* enum_name_offsets, - int num_valid_values, - int32_t* lengths, - int num_rows); - -__global__ void copy_enum_string_chars_kernel(int32_t const* values, - bool const* valid, - int32_t const* valid_enum_values, - int32_t const* enum_name_offsets, - uint8_t const* enum_name_chars, - int num_valid_values, - int32_t const* output_offsets, - char* out_chars, - int num_rows); - inline void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_invalid, rmm::device_uvector& row_invalid, int num_items, @@ -776,10 +743,7 @@ inline std::unique_ptr extract_typed_column( dt, num_items, [&](float* out_ptr, bool* valid_ptr) { - extract_fixed_kernel + extract_fixed_kernel <<>>(message_data, loc_provider, num_items, @@ -798,10 +762,7 @@ inline std::unique_ptr extract_typed_column( dt, num_items, [&](double* out_ptr, bool* valid_ptr) { - extract_fixed_kernel + extract_fixed_kernel <<>>(message_data, loc_provider, num_items, @@ -879,31 +840,24 @@ inline std::unique_ptr build_repeated_scalar_column( auto const blocks = static_cast((total_count + threads - 1u) 
/ threads); int encoding = field_desc.encoding; - bool zigzag = (encoding == spark_rapids_jni::protobuf::encoding_value( - spark_rapids_jni::protobuf::proto_encoding::ZIGZAG)); + bool zigzag = (encoding == encoding_value(proto_encoding::ZIGZAG)); // For float/double types, always use fixed kernel (they use wire type 32BIT/64BIT) // For integer types, use fixed kernel only if encoding is - // spark_rapids_jni::protobuf::encoding_value(spark_rapids_jni::protobuf::proto_encoding::FIXED) + // encoding_value(proto_encoding::FIXED) constexpr bool is_floating_point = std::is_same_v || std::is_same_v; - bool use_fixed_kernel = - is_floating_point || (encoding == spark_rapids_jni::protobuf::encoding_value( - spark_rapids_jni::protobuf::proto_encoding::FIXED)); + bool use_fixed_kernel = is_floating_point || (encoding == encoding_value(proto_encoding::FIXED)); RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; if (use_fixed_kernel) { if constexpr (sizeof(T) == 4) { - extract_fixed_kernel<<>>( - message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + extract_fixed_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } else { - extract_fixed_kernel<<>>( - message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + extract_fixed_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } } else if (zigzag) { extract_varint_kernel diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cu b/src/main/cpp/src/protobuf/protobuf_kernels.cu index bf34ec0c67..6b406f695b 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cu @@ -99,8 +99,7 @@ __global__ void scan_all_fields_kernel( // Record the location (relative to message start) int data_offset = static_cast(cur - bytes - start); - if (wt == spark_rapids_jni::protobuf::wire_type_value( - 
spark_rapids_jni::protobuf::proto_wire_type::LEN)) { + if (wt == wire_type_value(proto_wire_type::LEN)) { // For length-delimited, record offset after length prefix and the data length uint64_t len; int len_bytes; @@ -166,10 +165,8 @@ __device__ bool count_repeated_element(uint8_t const* cur, repeated_field_info& info, int* error_flag) { - bool is_packed = (wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::LEN) && - expected_wt != spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::LEN)); + bool is_packed = (wt == wire_type_value(proto_wire_type::LEN) && + expected_wt != wire_type_value(proto_wire_type::LEN)); if (!is_packed && wt != expected_wt) { set_error_once(error_flag, ERR_WIRE_TYPE); @@ -191,8 +188,7 @@ __device__ bool count_repeated_element(uint8_t const* cur, uint8_t const* packed_end = packed_start + packed_len; int count = 0; - if (expected_wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::VARINT)) { + if (expected_wt == wire_type_value(proto_wire_type::VARINT)) { uint8_t const* p = packed_start; while (p < packed_end) { uint64_t dummy; @@ -204,15 +200,13 @@ __device__ bool count_repeated_element(uint8_t const* cur, p += vbytes; count++; } - } else if (expected_wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { + } else if (expected_wt == wire_type_value(proto_wire_type::I32BIT)) { if ((packed_len % 4) != 0) { set_error_once(error_flag, ERR_FIXED_LEN); return false; } count = static_cast(packed_len / 4); - } else if (expected_wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I64BIT)) { + } else if (expected_wt == wire_type_value(proto_wire_type::I64BIT)) { if ((packed_len % 8) != 0) { set_error_once(error_flag, ERR_FIXED_LEN); return false; @@ -251,10 +245,8 @@ __device__ bool scan_repeated_element(uint8_t const* 
cur, int write_end, int* error_flag) { - bool is_packed = (wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::LEN) && - expected_wt != spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::LEN)); + bool is_packed = (wt == wire_type_value(proto_wire_type::LEN) && + expected_wt != wire_type_value(proto_wire_type::LEN)); if (!is_packed && wt != expected_wt) { set_error_once(error_flag, ERR_WIRE_TYPE); @@ -275,8 +267,7 @@ __device__ bool scan_repeated_element(uint8_t const* cur, } uint8_t const* packed_end = packed_start + packed_len; - if (expected_wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::VARINT)) { + if (expected_wt == wire_type_value(proto_wire_type::VARINT)) { uint8_t const* p = packed_start; while (p < packed_end) { int32_t elem_offset = static_cast(p - msg_base); @@ -294,8 +285,7 @@ __device__ bool scan_repeated_element(uint8_t const* cur, write_idx++; p += vbytes; } - } else if (expected_wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { + } else if (expected_wt == wire_type_value(proto_wire_type::I32BIT)) { if ((packed_len % 4) != 0) { set_error_once(error_flag, ERR_FIXED_LEN); return false; @@ -308,8 +298,7 @@ __device__ bool scan_repeated_element(uint8_t const* cur, occurrences[write_idx] = {row, static_cast(packed_start - msg_base + i), 4}; write_idx++; } - } else if (expected_wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I64BIT)) { + } else if (expected_wt == wire_type_value(proto_wire_type::I64BIT)) { if ((packed_len % 8) != 0) { set_error_once(error_flag, ERR_FIXED_LEN); return false; @@ -441,8 +430,7 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in // Check nested message fields at this depth auto handle_nested = [&](int i) { - if (wt != 
spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::LEN)) { + if (wt != wire_type_value(proto_wire_type::LEN)) { set_error_once(error_flag, ERR_WIRE_TYPE); return false; } @@ -458,8 +446,8 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in return false; } auto const rel_offset64 = static_cast(cur - bytes - start); - if (rel_offset64 < std::numeric_limits::min() || - rel_offset64 > std::numeric_limits::max()) { + if (rel_offset64 < cuda::std::numeric_limits::min() || + rel_offset64 > cuda::std::numeric_limits::max()) { set_error_once(error_flag, ERR_OVERFLOW); return false; } @@ -546,10 +534,8 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co auto try_scan = [&](int f) -> bool { int target_wt = scan_descs[f].wire_type; - bool is_packed = (wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::LEN) && - target_wt != spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::LEN)); + bool is_packed = (wt == wire_type_value(proto_wire_type::LEN) && + target_wt != wire_type_value(proto_wire_type::LEN)); if (is_packed || wt == target_wt) { return scan_repeated_element(cur, msg_end, @@ -662,8 +648,7 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, int data_offset = static_cast(cur - nested_start); - if (wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::LEN)) { + if (wt == wire_type_value(proto_wire_type::LEN)) { uint64_t len; int len_bytes; if (!read_varint(cur, nested_end, len, len_bytes)) { @@ -766,8 +751,7 @@ __global__ void scan_repeated_message_children_kernel( } else { int data_offset = static_cast(cur - msg_start); - if (wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::LEN)) { + if (wt == wire_type_value(proto_wire_type::LEN)) { uint64_t len; int len_bytes; 
if (!read_varint(cur, msg_end, len, len_bytes)) { @@ -791,8 +775,7 @@ __global__ void scan_repeated_message_children_kernel( } else { // For varint/fixed types, store offset and estimated length int32_t data_length = 0; - if (wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::VARINT)) { + if (wt == wire_type_value(proto_wire_type::VARINT)) { uint64_t dummy; int vbytes; if (!read_varint(cur, msg_end, dummy, vbytes)) { @@ -800,15 +783,13 @@ __global__ void scan_repeated_message_children_kernel( return; } data_length = vbytes; - } else if (wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { + } else if (wt == wire_type_value(proto_wire_type::I32BIT)) { if (msg_end - cur < 4) { set_error_once(error_flag, ERR_FIXED_LEN); return; } data_length = 4; - } else if (wt == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I64BIT)) { + } else if (wt == wire_type_value(proto_wire_type::I64BIT)) { if (msg_end - cur < 8) { set_error_once(error_flag, ERR_FIXED_LEN); return; @@ -1014,8 +995,8 @@ __global__ void compute_nested_struct_locations_kernel( static_cast(num_child_fields), static_cast(child_idx))]; auto sum = static_cast(msg_row_offsets[idx]) + msg_locs[idx].offset; - if (sum < std::numeric_limits::min() || - sum > std::numeric_limits::max()) { + if (sum < cuda::std::numeric_limits::min() || + sum > cuda::std::numeric_limits::max()) { nested_locs[idx] = {-1, 0}; nested_row_offsets[idx] = 0; set_error_once(error_flag, ERR_OVERFLOW); @@ -1049,7 +1030,8 @@ __global__ void compute_grandchild_parent_locations_kernel( if (parent_loc.offset >= 0 && child_loc.offset >= 0) { // Absolute offset = parent offset + child's relative offset auto sum = static_cast(parent_loc.offset) + child_loc.offset; - if (sum < std::numeric_limits::min() || sum > std::numeric_limits::max()) { + if (sum < cuda::std::numeric_limits::min() || + sum > 
cuda::std::numeric_limits::max()) { gc_parent_abs[row] = {-1, 0}; set_error_once(error_flag, ERR_OVERFLOW); return; @@ -1088,7 +1070,8 @@ __global__ void compute_virtual_parents_for_nested_repeated_kernel( // struct with all-null children (not a null struct). if (ploc.offset >= 0) { auto sum = static_cast(ploc.offset) + occ.offset; - if (sum < std::numeric_limits::min() || sum > std::numeric_limits::max()) { + if (sum < cuda::std::numeric_limits::min() || + sum > cuda::std::numeric_limits::max()) { virtual_parent_locs[idx] = {-1, 0}; set_error_once(error_flag, ERR_OVERFLOW); return; @@ -1117,8 +1100,8 @@ __global__ void compute_msg_locations_from_occurrences_kernel( auto const& occ = occurrences[idx]; auto row_offset = static_cast(list_offsets[occ.row_idx]) - base_offset; - if (row_offset < std::numeric_limits::min() || - row_offset > std::numeric_limits::max()) { + if (row_offset < cuda::std::numeric_limits::min() || + row_offset > cuda::std::numeric_limits::max()) { msg_row_offsets[idx] = 0; msg_locs[idx] = {-1, 0}; set_error_once(error_flag, ERR_OVERFLOW); diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cuh b/src/main/cpp/src/protobuf/protobuf_kernels.cuh index 1aaca16ca3..66869a4fe3 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cuh +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cuh @@ -209,8 +209,7 @@ __global__ void extract_fixed_kernel(uint8_t const* message_data, uint8_t const* cur = message_data + data_offset; OutputType value; - if constexpr (WT == spark_rapids_jni::protobuf::wire_type_value( - spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { + if constexpr (WT == wire_type_value(proto_wire_type::I32BIT)) { if (loc.length < 4) { set_error_once(error_flag, ERR_FIXED_LEN); if (valid) valid[idx] = false; @@ -335,8 +334,7 @@ __global__ void extract_fixed_batched_kernel(uint8_t const* message_data, uint8_t const* cur = message_data + data_offset; OutputType value; - if constexpr (WT == spark_rapids_jni::protobuf::wire_type_value( - 
spark_rapids_jni::protobuf::proto_wire_type::I32BIT)) { + if constexpr (WT == wire_type_value(proto_wire_type::I32BIT)) { if (loc.length < 4) { set_error_once(error_flag, ERR_FIXED_LEN); desc.valid[row] = false; diff --git a/src/main/cpp/src/protobuf/protobuf_types.cuh b/src/main/cpp/src/protobuf/protobuf_types.cuh index 35d51966f1..e8ffa4a292 100644 --- a/src/main/cpp/src/protobuf/protobuf_types.cuh +++ b/src/main/cpp/src/protobuf/protobuf_types.cuh @@ -109,8 +109,10 @@ struct device_nested_field_descriptor { device_nested_field_descriptor() = default; - explicit device_nested_field_descriptor( - spark_rapids_jni::protobuf::nested_field_descriptor const& src) + // Wire type and encoding are stored as int (not typed enums) because CUDA device code + // historically had limited constexpr enum support, and the kernel comparison sites use + // int-typed wire_type_value()/encoding_value() helpers throughout. + explicit device_nested_field_descriptor(nested_field_descriptor const& src) : field_number(src.field_number), parent_idx(src.parent_idx), depth(src.depth), From 82bf02ab2ac61598363e905c2a50fe732c3ade03 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Fri, 20 Mar 2026 21:16:23 +0800 Subject: [PATCH 097/107] backport suggestions Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf/protobuf.cu | 73 +- .../cpp/src/protobuf/protobuf_builders.cu | 272 ++++---- .../src/protobuf/protobuf_device_helpers.cuh | 8 +- .../src/protobuf/protobuf_host_helpers.hpp | 318 +++++---- src/main/cpp/src/protobuf/protobuf_kernels.cu | 629 ++++++++++++++++-- .../cpp/src/protobuf/protobuf_kernels.cuh | 257 ++----- 6 files changed, 972 insertions(+), 585 deletions(-) diff --git a/src/main/cpp/src/protobuf/protobuf.cu b/src/main/cpp/src/protobuf/protobuf.cu index 63a655eed6..9831a6f71c 100644 --- a/src/main/cpp/src/protobuf/protobuf.cu +++ b/src/main/cpp/src/protobuf/protobuf.cu @@ -16,6 +16,8 @@ #include "protobuf/protobuf_host_helpers.hpp" +#include + #include #include #include @@ 
-477,22 +479,23 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& stream.value())); } - count_repeated_fields_kernel<<>>( - *d_in, - d_schema.data(), - num_fields, - 0, - d_repeated_info.data(), - num_repeated, - d_repeated_indices.data(), - d_nested_locations.data(), - num_nested, - d_nested_indices.data(), - d_error.data(), - d_fn_to_rep.data(), - static_cast(d_fn_to_rep.size()), - d_fn_to_nested.data(), - static_cast(d_fn_to_nested.size())); + launch_count_repeated_fields(*d_in, + d_schema.data(), + num_fields, + 0, + d_repeated_info.data(), + num_repeated, + d_repeated_indices.data(), + d_nested_locations.data(), + num_nested, + d_nested_indices.data(), + d_error.data(), + d_fn_to_rep.data(), + static_cast(d_fn_to_rep.size()), + d_fn_to_nested.data(), + static_cast(d_fn_to_nested.size()), + num_rows, + stream); } // Store decoded columns by schema index for ordered assembly at the end. @@ -528,15 +531,16 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& stream.value())); } - scan_all_fields_kernel<<>>( - *d_in, - d_field_descs.data(), - num_scalar, - h_field_lookup.empty() ? nullptr : d_field_lookup.data(), - static_cast(h_field_lookup.size()), - d_locations.data(), - d_error.data(), - track_permissive_null_rows ? d_row_force_null.data() : nullptr); + launch_scan_all_fields(*d_in, + d_field_descs.data(), + num_scalar, + h_field_lookup.empty() ? nullptr : d_field_lookup.data(), + static_cast(h_field_lookup.size()), + d_locations.data(), + d_error.data(), + track_permissive_null_rows ? d_row_force_null.data() : nullptr, + num_rows, + stream); // Required-field validation applies to all scalar leaves, not just top-level numerics. 
maybe_check_required_fields(d_locations.data(), @@ -964,13 +968,14 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& fn_to_scan_size = static_cast(h_fn_to_scan.size()); } - scan_all_repeated_occurrences_kernel<<>>( - *d_in, - d_scan_descs.data(), - static_cast(h_scan_descs.size()), - d_error.data(), - fn_to_scan_size > 0 ? d_fn_to_scan.data() : nullptr, - fn_to_scan_size); + launch_scan_all_repeated_occurrences(*d_in, + d_scan_descs.data(), + static_cast(h_scan_descs.size()), + d_error.data(), + fn_to_scan_size > 0 ? d_fn_to_scan.data() : nullptr, + fn_to_scan_size, + num_rows, + stream); } // Phase C: Build columns per field. @@ -1249,8 +1254,8 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Extract parent locations for this nested field directly on GPU rmm::device_uvector d_parent_locs(num_rows, stream, mr); - extract_strided_locations_kernel<<>>( - d_nested_locations.data(), ni, num_nested, d_parent_locs.data(), num_rows); + launch_extract_strided_locations( + d_nested_locations.data(), ni, num_nested, d_parent_locs.data(), num_rows, stream); column_map[parent_schema_idx] = build_nested_struct_column(message_data, message_data_size, diff --git a/src/main/cpp/src/protobuf/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu index 7888696adc..0d63e7cddb 100644 --- a/src/main/cpp/src/protobuf/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -304,34 +304,31 @@ std::unique_ptr build_enum_string_values_column( rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((num_rows + threads - 1u) / threads); - rmm::device_uvector lengths(num_rows, stream, mr); - compute_enum_string_lengths_kernel<<>>( - enum_values.data(), - valid.data(), - lookup.d_valid_enums.data(), - lookup.d_name_offsets.data(), - static_cast(lookup.d_valid_enums.size()), - lengths.data(), - num_rows); + 
launch_compute_enum_string_lengths(enum_values.data(), + valid.data(), + lookup.d_valid_enums.data(), + lookup.d_name_offsets.data(), + static_cast(lookup.d_valid_enums.size()), + lengths.data(), + num_rows, + stream); auto [offsets_col, total_chars] = cudf::strings::detail::make_offsets_child_column(lengths.begin(), lengths.end(), stream, mr); rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { - copy_enum_string_chars_kernel<<>>( - enum_values.data(), - valid.data(), - lookup.d_valid_enums.data(), - lookup.d_name_offsets.data(), - lookup.d_name_chars.data(), - static_cast(lookup.d_valid_enums.size()), - offsets_col->view().data(), - chars.data(), - num_rows); + launch_copy_enum_string_chars(enum_values.data(), + valid.data(), + lookup.d_valid_enums.data(), + lookup.d_name_offsets.data(), + lookup.d_name_chars.data(), + static_cast(lookup.d_valid_enums.size()), + offsets_col->view().data(), + chars.data(), + num_rows, + stream); } auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); @@ -351,22 +348,20 @@ std::unique_ptr build_enum_string_column( int32_t const* top_row_indices, bool propagate_invalid_rows) { - auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((num_rows + threads - 1u) / threads); - auto lookup = make_enum_string_lookup_tables(valid_enums, enum_name_bytes, stream, mr); + auto lookup = make_enum_string_lookup_tables(valid_enums, enum_name_bytes, stream, mr); rmm::device_uvector d_item_has_invalid_enum(num_rows, stream, mr); thrust::fill(rmm::exec_policy_nosync(stream), d_item_has_invalid_enum.begin(), d_item_has_invalid_enum.end(), false); - validate_enum_values_kernel<<>>( - enum_values.data(), - valid.data(), - d_item_has_invalid_enum.data(), - lookup.d_valid_enums.data(), - static_cast(valid_enums.size()), - num_rows); + launch_validate_enum_values(enum_values.data(), + valid.data(), + d_item_has_invalid_enum.data(), + lookup.d_valid_enums.data(), + static_cast(valid_enums.size()), 
+ num_rows, + stream); propagate_invalid_enum_flags_to_rows(d_item_has_invalid_enum, d_row_force_null, num_rows, @@ -421,13 +416,13 @@ std::unique_ptr build_repeated_msg_child_enum_string_column( d_elem_has_invalid_enum.begin(), d_elem_has_invalid_enum.end(), false); - validate_enum_values_kernel<<>>( - enum_values.data(), - valid.data(), - d_elem_has_invalid_enum.data(), - lookup.d_valid_enums.data(), - static_cast(valid_enums.size()), - total_count); + launch_validate_enum_values(enum_values.data(), + valid.data(), + d_elem_has_invalid_enum.data(), + lookup.d_valid_enums.data(), + static_cast(valid_enums.size()), + total_count, + stream); propagate_invalid_enum_flags_to_rows(d_elem_has_invalid_enum, d_row_force_null, total_count, @@ -480,13 +475,13 @@ std::unique_ptr build_repeated_enum_string_column( d_elem_has_invalid_enum.begin(), d_elem_has_invalid_enum.end(), false); - validate_enum_values_kernel<<>>( - enum_ints.data(), - elem_valid.data(), - d_elem_has_invalid_enum.data(), - lookup.d_valid_enums.data(), - static_cast(valid_enums.size()), - total_count); + launch_validate_enum_values(enum_ints.data(), + elem_valid.data(), + d_elem_has_invalid_enum.data(), + lookup.d_valid_enums.data(), + static_cast(valid_enums.size()), + total_count, + stream); rmm::device_uvector d_top_row_indices(total_count, stream, mr); thrust::transform(rmm::exec_policy_nosync(stream), @@ -807,18 +802,14 @@ std::unique_ptr build_repeated_struct_column( // This replaces the host-side loop with D->H->D copy pattern (critical performance fix!) 
rmm::device_uvector d_msg_locs(total_count, stream, mr); rmm::device_uvector d_msg_row_offsets(total_count, stream, mr); - { - auto const occ_threads = THREADS_PER_BLOCK; - auto const occ_blocks = static_cast((total_count + occ_threads - 1u) / occ_threads); - compute_msg_locations_from_occurrences_kernel<<>>( - d_occurrences.data(), - list_offsets, - base_offset, - d_msg_locs.data(), - d_msg_row_offsets.data(), - total_count, - d_error_top.data()); - } + launch_compute_msg_locations_from_occurrences(d_occurrences.data(), + list_offsets, + base_offset, + d_msg_locs.data(), + d_msg_row_offsets.data(), + total_count, + d_error_top.data(), + stream); rmm::device_uvector d_top_row_indices(total_count, stream, mr); thrust::transform(rmm::exec_policy_nosync(stream), d_occurrences.data(), @@ -836,18 +827,18 @@ std::unique_ptr build_repeated_struct_column( // Use a custom kernel to scan child fields within message occurrences // This is similar to scan_nested_message_fields_kernel but operates on occurrences - scan_repeated_message_children_kernel<<>>( - message_data, - message_data_size, - d_msg_row_offsets.data(), - d_msg_locs.data(), - total_count, - d_child_descs.data(), - num_child_fields, - d_child_locs.data(), - d_error.data(), - h_child_lookup.empty() ? nullptr : d_child_lookup.data(), - static_cast(d_child_lookup.size())); + launch_scan_repeated_message_children(message_data, + message_data_size, + d_msg_row_offsets.data(), + d_msg_locs.data(), + total_count, + d_child_descs.data(), + num_child_fields, + d_child_locs.data(), + d_error.data(), + h_child_lookup.empty() ? nullptr : d_child_lookup.data(), + static_cast(d_child_lookup.size()), + stream); // Enforce proto2 required semantics for fields inside each repeated message occurrence. 
maybe_check_required_fields(d_child_locs.data(), @@ -1015,16 +1006,16 @@ std::unique_ptr build_repeated_struct_column( // Compute virtual parent locations for each occurrence's nested struct child rmm::device_uvector d_nested_locs(total_count, stream, mr); rmm::device_uvector d_nested_row_offsets(total_count, stream, mr); - compute_nested_struct_locations_kernel<<>>( - d_child_locs.data(), - d_msg_locs.data(), - d_msg_row_offsets.data(), - ci, - num_child_fields, - d_nested_locs.data(), - d_nested_row_offsets.data(), - total_count, - d_error_top.data()); + launch_compute_nested_struct_locations(d_child_locs.data(), + d_msg_locs.data(), + d_msg_row_offsets.data(), + ci, + num_child_fields, + d_nested_locs.data(), + d_nested_row_offsets.data(), + total_count, + d_error_top.data(), + stream); struct_children.push_back(build_nested_struct_column(message_data, message_data_size, @@ -1151,17 +1142,17 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector d_child_locations( static_cast(num_rows) * num_child_fields, stream, mr); - scan_nested_message_fields_kernel<<>>( - message_data, - message_data_size, - list_offsets, - base_offset, - d_parent_locs.data(), - num_rows, - d_child_field_descs.data(), - num_child_fields, - d_child_locations.data(), - d_error.data()); + launch_scan_nested_message_fields(message_data, + message_data_size, + list_offsets, + base_offset, + d_parent_locs.data(), + num_rows, + d_child_field_descs.data(), + num_child_fields, + d_child_locations.data(), + d_error.data(), + stream); // Enforce proto2 required semantics for direct children of this nested message. 
maybe_check_required_fields(d_child_locations.data(), @@ -1381,14 +1372,14 @@ std::unique_ptr build_nested_struct_column( break; } rmm::device_uvector d_gc_parent(num_rows, stream, mr); - compute_grandchild_parent_locations_kernel<<>>( - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields, - d_gc_parent.data(), - num_rows, - d_error.data()); + launch_compute_grandchild_parent_locations(d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + d_gc_parent.data(), + num_rows, + d_error.data(), + stream); struct_children.push_back(build_nested_struct_column(message_data, message_data_size, list_offsets, @@ -1457,9 +1448,6 @@ std::unique_ptr build_repeated_child_list_column( int depth, bool propagate_invalid_rows) { - auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((num_parent_rows + threads - 1u) / threads); - auto elem_type_id = schema[child_schema_idx].output_type; rmm::device_uvector d_rep_info(num_parent_rows, stream, mr); @@ -1491,18 +1479,19 @@ std::unique_ptr build_repeated_child_list_column( cudaMemcpyHostToDevice, stream.value())); - count_repeated_in_nested_kernel<<>>(message_data, - message_data_size, - row_offsets, - base_offset, - parent_locs, - num_parent_rows, - d_rep_schema.data(), - 1, - d_rep_info.data(), - 1, - d_rep_indices.data(), - d_error.data()); + launch_count_repeated_in_nested(message_data, + message_data_size, + row_offsets, + base_offset, + parent_locs, + num_parent_rows, + d_rep_schema.data(), + 1, + d_rep_info.data(), + 1, + d_rep_indices.data(), + d_error.data(), + stream); rmm::device_uvector d_rep_counts(num_parent_rows, stream, mr); thrust::transform(rmm::exec_policy_nosync(stream), @@ -1540,17 +1529,18 @@ std::unique_ptr build_repeated_child_list_column( rmm::exec_policy_nosync(stream), list_offs.data() + num_parent_rows, 1, total_rep_count); rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); - scan_repeated_in_nested_kernel<<>>(message_data, - 
message_data_size, - row_offsets, - base_offset, - parent_locs, - num_parent_rows, - d_rep_schema.data(), - list_offs.data(), - d_rep_indices.data(), - d_rep_occs.data(), - d_error.data()); + launch_scan_repeated_in_nested(message_data, + message_data_size, + row_offsets, + base_offset, + parent_locs, + num_parent_rows, + d_rep_schema.data(), + list_offs.data(), + d_rep_indices.data(), + d_rep_occs.data(), + d_error.data(), + stream); rmm::device_uvector d_rep_top_row_indices(total_rep_count, stream, mr); thrust::transform(rmm::exec_policy_nosync(stream), @@ -1618,13 +1608,13 @@ std::unique_ptr build_repeated_child_list_column( d_elem_has_invalid_enum.begin(), d_elem_has_invalid_enum.end(), false); - validate_enum_values_kernel<<>>( - enum_values.data(), - valid.data(), - d_elem_has_invalid_enum.data(), - lookup.d_valid_enums.data(), - static_cast(lookup.d_valid_enums.size()), - total_rep_count); + launch_validate_enum_values(enum_values.data(), + valid.data(), + d_elem_has_invalid_enum.data(), + lookup.d_valid_enums.data(), + static_cast(lookup.d_valid_enums.size()), + total_rep_count, + stream); propagate_invalid_enum_flags_to_rows(d_elem_has_invalid_enum, d_row_force_null, total_rep_count, @@ -1666,18 +1656,14 @@ std::unique_ptr build_repeated_child_list_column( } else { rmm::device_uvector d_virtual_row_offsets(total_rep_count, stream, mr); rmm::device_uvector d_virtual_parent_locs(total_rep_count, stream, mr); - auto const rep_blk = (total_rep_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK; - compute_virtual_parents_for_nested_repeated_kernel<<>>( - d_rep_occs.data(), - row_offsets, - parent_locs, - d_virtual_row_offsets.data(), - d_virtual_parent_locs.data(), - total_rep_count, - d_error.data()); + launch_compute_virtual_parents_for_nested_repeated(d_rep_occs.data(), + row_offsets, + parent_locs, + d_virtual_row_offsets.data(), + d_virtual_parent_locs.data(), + total_rep_count, + d_error.data(), + stream); child_values = 
build_nested_struct_column(message_data, message_data_size, diff --git a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh index fe020bd825..9e01176eac 100644 --- a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh +++ b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh @@ -62,13 +62,7 @@ __device__ inline void set_error_once(int* error_flag, int error_code) ref.compare_exchange_strong(expected, error_code, cuda::memory_order_relaxed); } -__global__ void set_error_if_unset_kernel(int* error_flag, int error_code); - -inline void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stream_view stream) -{ - set_error_if_unset_kernel<<<1, 1, 0, stream.value()>>>(error_flag, error_code); - CUDF_CUDA_TRY(cudaPeekAtLastError()); -} +void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stream_view stream); __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t const* end) { diff --git a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp index bb426c5f3b..326b07be13 100644 --- a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp +++ b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp @@ -295,129 +295,203 @@ std::unique_ptr make_empty_struct_column_with_schema( return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer{}, stream, mr); } -inline void maybe_check_required_fields(field_location const* locations, - std::vector const& field_indices, - std::vector const& schema, - int num_rows, - cudf::bitmask_type const* input_null_mask, - cudf::size_type input_offset, - field_location const* parent_locs, - bool* row_force_null, - int32_t const* top_row_indices, - int* error_flag, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - if (num_rows == 0 || field_indices.empty()) { return; } - - bool has_required = false; - std::vector h_is_required(field_indices.size()); 
- for (size_t i = 0; i < field_indices.size(); ++i) { - h_is_required[i] = schema[field_indices[i]].is_required ? 1 : 0; - has_required |= (h_is_required[i] != 0); - } - if (!has_required) { return; } - - rmm::device_uvector d_is_required(field_indices.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_is_required.data(), - h_is_required.data(), - h_is_required.size() * sizeof(uint8_t), - cudaMemcpyHostToDevice, - stream.value())); - - auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); - check_required_fields_kernel<<>>( - locations, - d_is_required.data(), - static_cast(field_indices.size()), - num_rows, - input_null_mask, - input_offset, - parent_locs, - row_force_null, - top_row_indices, - error_flag); -} - -inline void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_invalid, - rmm::device_uvector& row_invalid, - int num_items, - int32_t const* top_row_indices, - bool propagate_to_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - if (num_items == 0 || row_invalid.size() == 0 || !propagate_to_rows) { return; } - - if (top_row_indices == nullptr) { - CUDF_EXPECTS(static_cast(num_items) <= row_invalid.size(), - "enum invalid-row propagation exceeded row buffer"); - thrust::transform(rmm::exec_policy_nosync(stream), - row_invalid.begin(), - row_invalid.begin() + num_items, - item_invalid.begin(), - row_invalid.begin(), - [] __device__(bool row_is_invalid, bool item_is_invalid) { - return row_is_invalid || item_is_invalid; - }); - return; - } - - rmm::device_uvector invalid_rows(num_items, stream, mr); - thrust::transform(rmm::exec_policy_nosync(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_items), - invalid_rows.begin(), - [item_invalid = item_invalid.data(), top_row_indices] __device__(int idx) { - return item_invalid[idx] ? 
top_row_indices[idx] : -1; - }); - - auto valid_end = - thrust::remove(rmm::exec_policy_nosync(stream), invalid_rows.begin(), invalid_rows.end(), -1); - thrust::sort(rmm::exec_policy_nosync(stream), invalid_rows.begin(), valid_end); - auto unique_end = - thrust::unique(rmm::exec_policy_nosync(stream), invalid_rows.begin(), valid_end); - thrust::for_each(rmm::exec_policy_nosync(stream), - invalid_rows.begin(), - unique_end, - [row_invalid = row_invalid.data()] __device__(int32_t row_idx) { - row_invalid[row_idx] = true; - }); -} +// ============================================================================ +// Host wrapper declarations for kernel launches +// ============================================================================ -inline void validate_enum_and_propagate_rows(rmm::device_uvector const& values, - rmm::device_uvector& valid, - std::vector const& valid_enums, - rmm::device_uvector& row_invalid, - int num_items, - int32_t const* top_row_indices, - bool propagate_to_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - if (num_items == 0 || valid_enums.empty()) { return; } - - auto const blocks = static_cast((num_items + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); - - rmm::device_uvector item_invalid(num_items, stream, mr); - thrust::fill(rmm::exec_policy_nosync(stream), item_invalid.begin(), item_invalid.end(), false); - validate_enum_values_kernel<<>>( - values.data(), - valid.data(), - item_invalid.data(), - d_valid_enums.data(), - static_cast(valid_enums.size()), - num_items); - - propagate_invalid_enum_flags_to_rows( - item_invalid, row_invalid, num_items, top_row_indices, propagate_to_rows, stream, mr); -} +void launch_scan_all_fields(cudf::column_device_view const& d_in, + 
field_descriptor const* field_descs, + int num_fields, + int const* field_lookup, + int field_lookup_size, + field_location* locations, + int* error_flag, + bool* row_has_invalid_data, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_count_repeated_fields(cudf::column_device_view const& d_in, + device_nested_field_descriptor const* schema, + int num_fields, + int depth_level, + repeated_field_info* repeated_info, + int num_repeated_fields, + int const* repeated_field_indices, + field_location* nested_locations, + int num_nested_fields, + int const* nested_field_indices, + int* error_flag, + int const* fn_to_rep_idx, + int fn_to_rep_size, + int const* fn_to_nested_idx, + int fn_to_nested_size, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_scan_all_repeated_occurrences(cudf::column_device_view const& d_in, + repeated_field_scan_desc const* scan_descs, + int num_scan_fields, + int* error_flag, + int const* fn_to_desc_idx, + int fn_to_desc_size, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_extract_strided_locations(field_location const* nested_locations, + int field_idx, + int num_fields, + field_location* parent_locs, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_scan_nested_message_fields(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + int num_parent_rows, + field_descriptor const* field_descs, + int num_fields, + field_location* output_locations, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_scan_repeated_message_children(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* msg_row_offsets, + field_location const* msg_locs, + int num_occurrences, + field_descriptor const* child_descs, + int num_child_fields, + field_location* child_locs, + int* error_flag, + int const* child_lookup, + int 
child_lookup_size, + rmm::cuda_stream_view stream); + +void launch_count_repeated_in_nested(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + repeated_field_info* repeated_info, + int num_repeated, + int const* repeated_indices, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_scan_repeated_in_nested(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int32_t const* occ_prefix_sums, + int const* repeated_indices, + repeated_occurrence* occurrences, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_compute_nested_struct_locations(field_location const* child_locs, + field_location const* msg_locs, + cudf::size_type const* msg_row_offsets, + int child_idx, + int num_child_fields, + field_location* nested_locs, + cudf::size_type* nested_row_offsets, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_compute_grandchild_parent_locations(field_location const* parent_locs, + field_location const* child_locs, + int child_idx, + int num_child_fields, + field_location* gc_parent_abs, + int num_rows, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_compute_virtual_parents_for_nested_repeated(repeated_occurrence const* occurrences, + cudf::size_type const* row_list_offsets, + field_location const* parent_locations, + cudf::size_type* virtual_row_offsets, + field_location* virtual_parent_locs, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_compute_msg_locations_from_occurrences(repeated_occurrence const* occurrences, + cudf::size_type const* list_offsets, + 
cudf::size_type base_offset, + field_location* msg_locs, + cudf::size_type* msg_row_offsets, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_validate_enum_values(int32_t const* values, + bool* valid, + bool* row_has_invalid_enum, + int32_t const* valid_enum_values, + int num_valid_values, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_compute_enum_string_lengths(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + int num_valid_values, + int32_t* lengths, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_copy_enum_string_chars(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + uint8_t const* enum_name_chars, + int num_valid_values, + int32_t const* output_offsets, + char* out_chars, + int num_rows, + rmm::cuda_stream_view stream); + +void maybe_check_required_fields(field_location const* locations, + std::vector const& field_indices, + std::vector const& schema, + int num_rows, + cudf::bitmask_type const* input_null_mask, + cudf::size_type input_offset, + field_location const* parent_locs, + bool* row_force_null, + int32_t const* top_row_indices, + int* error_flag, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_invalid, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +void validate_enum_and_propagate_rows(rmm::device_uvector const& values, + rmm::device_uvector& valid, + std::vector const& valid_enums, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); // 
============================================================================ // Forward declarations of builder/utility functions diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cu b/src/main/cpp/src/protobuf/protobuf_kernels.cu index 6b406f695b..174420dd7b 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cu @@ -16,13 +16,29 @@ #include "protobuf/protobuf_kernels.cuh" +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + namespace spark_rapids_jni::protobuf::detail { +namespace { // ============================================================================ // Pass 1: Scan all fields kernel - records (offset, length) for each field // ============================================================================ -__global__ void set_error_if_unset_kernel(int* error_flag, int error_code) +CUDF_KERNEL void set_error_if_unset_kernel(int* error_flag, int error_code) { if (blockIdx.x == 0 && threadIdx.x == 0) { set_error_once(error_flag, error_code); } } @@ -40,7 +56,7 @@ __global__ void set_error_if_unset_kernel(int* error_flag, int error_code) * row-level invalidity buffer so the full struct row can be nulled to match Spark CPU semantics for * malformed messages. */ -__global__ void scan_all_fields_kernel( +CUDF_KERNEL void scan_all_fields_kernel( cudf::column_device_view const d_in, field_descriptor const* field_descs, // [num_fields] int num_fields, @@ -340,21 +356,21 @@ __device__ bool scan_repeated_element(uint8_t const* cur, * Optional lookup tables (fn_to_rep_idx, fn_to_nested_idx) provide O(1) field_number * to local index mapping. When nullptr, falls back to linear search. 
*/ -__global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in, - device_nested_field_descriptor const* schema, - int num_fields, - int depth_level, - repeated_field_info* repeated_info, - int num_repeated_fields, - int const* repeated_field_indices, - field_location* nested_locations, - int num_nested_fields, - int const* nested_field_indices, - int* error_flag, - int const* fn_to_rep_idx, - int fn_to_rep_size, - int const* fn_to_nested_idx, - int fn_to_nested_size) +CUDF_KERNEL void count_repeated_fields_kernel(cudf::column_device_view const d_in, + device_nested_field_descriptor const* schema, + int num_fields, + int depth_level, + repeated_field_info* repeated_info, + int num_repeated_fields, + int const* repeated_field_indices, + field_location* nested_locations, + int num_nested_fields, + int const* nested_field_indices, + int* error_flag, + int const* fn_to_rep_idx, + int fn_to_rep_size, + int const* fn_to_nested_idx, + int fn_to_nested_size) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; @@ -493,12 +509,12 @@ __global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in * Combined occurrence scan: scans each message ONCE and writes occurrences for ALL * repeated fields simultaneously, scanning each message only once. 
*/ -__global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view const d_in, - repeated_field_scan_desc const* scan_descs, - int num_scan_fields, - int* error_flag, - int const* fn_to_desc_idx, - int fn_to_desc_size) +CUDF_KERNEL void scan_all_repeated_occurrences_kernel(cudf::column_device_view const d_in, + repeated_field_scan_desc const* scan_descs, + int num_scan_fields, + int* error_flag, + int const* fn_to_desc_idx, + int fn_to_desc_size) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; @@ -594,16 +610,16 @@ __global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view co * scanning, previous benchmarking did not show a stable win from adding a host-built lookup table * here, so we keep the simpler implementation unless that trade-off changes. */ -__global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, - field_location const* parent_locations, - int num_parent_rows, - field_descriptor const* field_descs, - int num_fields, - field_location* output_locations, - int* error_flag) +CUDF_KERNEL void scan_nested_message_fields_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + int num_parent_rows, + field_descriptor const* field_descs, + int num_fields, + field_location* output_locations, + int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_parent_rows) return; @@ -695,7 +711,7 @@ __global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, * Scan for child fields within repeated message occurrences. * Each occurrence is a protobuf message, and we need to find child field locations within it. 
*/ -__global__ void scan_repeated_message_children_kernel( +CUDF_KERNEL void scan_repeated_message_children_kernel( uint8_t const* message_data, cudf::size_type message_data_size, cudf::size_type const* msg_row_offsets, // Row offset for each occurrence @@ -822,18 +838,18 @@ __global__ void scan_repeated_message_children_kernel( * the depth is implicitly fixed. Callers must pre-filter repeated_indices to include only * fields at the expected child depth. */ -__global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_rows, - device_nested_field_descriptor const* schema, - int num_fields, - repeated_field_info* repeated_info, - int num_repeated, - int const* repeated_indices, - int* error_flag) +CUDF_KERNEL void count_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + repeated_field_info* repeated_info, + int num_repeated, + int const* repeated_indices, + int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; @@ -902,17 +918,17 @@ __global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, * * Note: no depth-level check is performed; see count_repeated_in_nested_kernel comment. 
*/ -__global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_rows, - device_nested_field_descriptor const* schema, - int32_t const* occ_prefix_sums, - int const* repeated_indices, - repeated_occurrence* occurrences, - int* error_flag) +CUDF_KERNEL void scan_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int32_t const* occ_prefix_sums, + int const* repeated_indices, + repeated_occurrence* occurrences, + int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (row >= num_rows) return; @@ -977,7 +993,7 @@ __global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, * Replaces host-side loop that was copying data D->H, processing, then H->D. * This is a critical performance optimization. */ -__global__ void compute_nested_struct_locations_kernel( +CUDF_KERNEL void compute_nested_struct_locations_kernel( field_location const* child_locs, // Child field locations from parent scan field_location const* msg_locs, // Parent message locations cudf::size_type const* msg_row_offsets, // Parent message row offsets @@ -1010,7 +1026,7 @@ __global__ void compute_nested_struct_locations_kernel( * Computes: gc_parent_abs[i] = parent[i].offset + child[i * ncf + ci].offset * This replaces host-side loop with D->H->D copy pattern. 
*/ -__global__ void compute_grandchild_parent_locations_kernel( +CUDF_KERNEL void compute_grandchild_parent_locations_kernel( field_location const* parent_locs, // Parent locations (row count) field_location const* child_locs, // Child locations (row * num_child_fields) int child_idx, // Which child field @@ -1048,7 +1064,7 @@ __global__ void compute_grandchild_parent_locations_kernel( * inside nested messages. Each occurrence becomes a virtual "row" so that * build_nested_struct_column can recursively process the children. */ -__global__ void compute_virtual_parents_for_nested_repeated_kernel( +CUDF_KERNEL void compute_virtual_parents_for_nested_repeated_kernel( repeated_occurrence const* occurrences, cudf::size_type const* row_list_offsets, // original binary input list offsets field_location const* parent_locations, // parent nested message locations @@ -1086,7 +1102,7 @@ __global__ void compute_virtual_parents_for_nested_repeated_kernel( * Kernel to compute message locations and row offsets from repeated occurrences. * Replaces host-side loop that processed occurrences. */ -__global__ void compute_msg_locations_from_occurrences_kernel( +CUDF_KERNEL void compute_msg_locations_from_occurrences_kernel( repeated_occurrence const* occurrences, // Repeated field occurrences cudf::size_type const* list_offsets, // List offsets for rows cudf::size_type base_offset, // Base offset to subtract @@ -1115,11 +1131,11 @@ __global__ void compute_msg_locations_from_occurrences_kernel( * Extract a single field's locations from a 2D strided array on the GPU. * Replaces a D2H + CPU loop + H2D pattern for nested message location extraction. 
*/ -__global__ void extract_strided_locations_kernel(field_location const* nested_locations, - int field_idx, - int num_fields, - field_location* parent_locs, - int num_rows) +CUDF_KERNEL void extract_strided_locations_kernel(field_location const* nested_locations, + int field_idx, + int num_fields, + field_location* parent_locs, + int num_rows) { int row = blockIdx.x * blockDim.x + threadIdx.x; if (row >= num_rows) return; @@ -1135,7 +1151,7 @@ __global__ void extract_strided_locations_kernel(field_location const* nested_lo * Check if any required fields are missing (offset < 0) and set error flag. * This is called after the scan pass to validate required field constraints. */ -__global__ void check_required_fields_kernel( +CUDF_KERNEL void check_required_fields_kernel( field_location const* locations, // [num_rows * num_fields] uint8_t const* is_required, // [num_fields] (1 = required, 0 = optional) int num_fields, @@ -1185,7 +1201,7 @@ __global__ void check_required_fields_kernel( * @note Time complexity: O(log(num_valid_values)) per row. */ -__global__ void validate_enum_values_kernel( +CUDF_KERNEL void validate_enum_values_kernel( int32_t const* values, // [num_rows] extracted enum values bool* valid, // [num_rows] field validity flags (will be modified) bool* row_has_invalid_enum, // [num_rows] row-level invalid enum flag (will be set to true) @@ -1232,7 +1248,7 @@ __global__ void validate_enum_values_kernel( * row_has_invalid_enum). */ -__global__ void compute_enum_string_lengths_kernel( +CUDF_KERNEL void compute_enum_string_lengths_kernel( int32_t const* values, // [num_rows] enum numeric values bool const* valid, // [num_rows] field validity int32_t const* valid_enum_values, // sorted enum numeric values @@ -1272,7 +1288,7 @@ __global__ void compute_enum_string_lengths_kernel( /** * Copy enum-as-string UTF-8 bytes into output chars buffer using precomputed row offsets. 
*/ -__global__ void copy_enum_string_chars_kernel( +CUDF_KERNEL void copy_enum_string_chars_kernel( int32_t const* values, // [num_rows] enum numeric values bool const* valid, // [num_rows] field validity int32_t const* valid_enum_values, // sorted enum numeric values @@ -1309,4 +1325,477 @@ __global__ void copy_enum_string_chars_kernel( } } +} // anonymous namespace + +// ============================================================================ +// Host wrapper functions — callable from other translation units +// ============================================================================ + +void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stream_view stream) +{ + set_error_if_unset_kernel<<<1, 1, 0, stream.value()>>>(error_flag, error_code); + CUDF_CUDA_TRY(cudaPeekAtLastError()); +} + +void launch_scan_all_fields(cudf::column_device_view const& d_in, + field_descriptor const* field_descs, + int num_fields, + int const* field_lookup, + int field_lookup_size, + field_location* locations, + int* error_flag, + bool* row_has_invalid_data, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + scan_all_fields_kernel<<>>(d_in, + field_descs, + num_fields, + field_lookup, + field_lookup_size, + locations, + error_flag, + row_has_invalid_data); +} + +void launch_count_repeated_fields(cudf::column_device_view const& d_in, + device_nested_field_descriptor const* schema, + int num_fields, + int depth_level, + repeated_field_info* repeated_info, + int num_repeated_fields, + int const* repeated_field_indices, + field_location* nested_locations, + int num_nested_fields, + int const* nested_field_indices, + int* error_flag, + int const* fn_to_rep_idx, + int fn_to_rep_size, + int const* fn_to_nested_idx, + int fn_to_nested_size, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = 
static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + count_repeated_fields_kernel<<>>( + d_in, + schema, + num_fields, + depth_level, + repeated_info, + num_repeated_fields, + repeated_field_indices, + nested_locations, + num_nested_fields, + nested_field_indices, + error_flag, + fn_to_rep_idx, + fn_to_rep_size, + fn_to_nested_idx, + fn_to_nested_size); +} + +void launch_scan_all_repeated_occurrences(cudf::column_device_view const& d_in, + repeated_field_scan_desc const* scan_descs, + int num_scan_fields, + int* error_flag, + int const* fn_to_desc_idx, + int fn_to_desc_size, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + scan_all_repeated_occurrences_kernel<<>>( + d_in, scan_descs, num_scan_fields, error_flag, fn_to_desc_idx, fn_to_desc_size); +} + +void launch_extract_strided_locations(field_location const* nested_locations, + int field_idx, + int num_fields, + field_location* parent_locs, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + extract_strided_locations_kernel<<>>( + nested_locations, field_idx, num_fields, parent_locs, num_rows); +} + +void launch_scan_nested_message_fields(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + int num_parent_rows, + field_descriptor const* field_descs, + int num_fields, + field_location* output_locations, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (num_parent_rows == 0) return; + auto const blocks = + static_cast((num_parent_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + scan_nested_message_fields_kernel<<>>( + message_data, + message_data_size, + parent_row_offsets, + parent_base_offset, + 
parent_locations, + num_parent_rows, + field_descs, + num_fields, + output_locations, + error_flag); +} + +void launch_scan_repeated_message_children(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* msg_row_offsets, + field_location const* msg_locs, + int num_occurrences, + field_descriptor const* child_descs, + int num_child_fields, + field_location* child_locs, + int* error_flag, + int const* child_lookup, + int child_lookup_size, + rmm::cuda_stream_view stream) +{ + if (num_occurrences == 0) return; + auto const blocks = + static_cast((num_occurrences + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + scan_repeated_message_children_kernel<<>>( + message_data, + message_data_size, + msg_row_offsets, + msg_locs, + num_occurrences, + child_descs, + num_child_fields, + child_locs, + error_flag, + child_lookup, + child_lookup_size); +} + +void launch_count_repeated_in_nested(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + repeated_field_info* repeated_info, + int num_repeated, + int const* repeated_indices, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + count_repeated_in_nested_kernel<<>>( + message_data, + message_data_size, + row_offsets, + base_offset, + parent_locs, + num_rows, + schema, + num_fields, + repeated_info, + num_repeated, + repeated_indices, + error_flag); +} + +void launch_scan_repeated_in_nested(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int32_t const* occ_prefix_sums, + int const* 
repeated_indices, + repeated_occurrence* occurrences, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + scan_repeated_in_nested_kernel<<>>( + message_data, + message_data_size, + row_offsets, + base_offset, + parent_locs, + num_rows, + schema, + occ_prefix_sums, + repeated_indices, + occurrences, + error_flag); +} + +void launch_compute_nested_struct_locations(field_location const* child_locs, + field_location const* msg_locs, + cudf::size_type const* msg_row_offsets, + int child_idx, + int num_child_fields, + field_location* nested_locs, + cudf::size_type* nested_row_offsets, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (total_count == 0) return; + auto const blocks = static_cast((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + compute_nested_struct_locations_kernel<<>>( + child_locs, + msg_locs, + msg_row_offsets, + child_idx, + num_child_fields, + nested_locs, + nested_row_offsets, + total_count, + error_flag); +} + +void launch_compute_grandchild_parent_locations(field_location const* parent_locs, + field_location const* child_locs, + int child_idx, + int num_child_fields, + field_location* gc_parent_abs, + int num_rows, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + compute_grandchild_parent_locations_kernel<<>>( + parent_locs, child_locs, child_idx, num_child_fields, gc_parent_abs, num_rows, error_flag); +} + +void launch_compute_virtual_parents_for_nested_repeated(repeated_occurrence const* occurrences, + cudf::size_type const* row_list_offsets, + field_location const* parent_locations, + cudf::size_type* virtual_row_offsets, + field_location* virtual_parent_locs, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (total_count == 0) 
return; + auto const blocks = static_cast((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + compute_virtual_parents_for_nested_repeated_kernel<<>>(occurrences, + row_list_offsets, + parent_locations, + virtual_row_offsets, + virtual_parent_locs, + total_count, + error_flag); +} + +void launch_compute_msg_locations_from_occurrences(repeated_occurrence const* occurrences, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + field_location* msg_locs, + cudf::size_type* msg_row_offsets, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (total_count == 0) return; + auto const blocks = static_cast((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + compute_msg_locations_from_occurrences_kernel<<>>( + occurrences, list_offsets, base_offset, msg_locs, msg_row_offsets, total_count, error_flag); +} + +void launch_validate_enum_values(int32_t const* values, + bool* valid, + bool* row_has_invalid_enum, + int32_t const* valid_enum_values, + int num_valid_values, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + validate_enum_values_kernel<<>>( + values, valid, row_has_invalid_enum, valid_enum_values, num_valid_values, num_rows); +} + +void launch_compute_enum_string_lengths(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + int num_valid_values, + int32_t* lengths, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + compute_enum_string_lengths_kernel<<>>( + values, valid, valid_enum_values, enum_name_offsets, num_valid_values, lengths, num_rows); +} + +void launch_copy_enum_string_chars(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, 
+ uint8_t const* enum_name_chars, + int num_valid_values, + int32_t const* output_offsets, + char* out_chars, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + copy_enum_string_chars_kernel<<>>(values, + valid, + valid_enum_values, + enum_name_offsets, + enum_name_chars, + num_valid_values, + output_offsets, + out_chars, + num_rows); +} + +void maybe_check_required_fields(field_location const* locations, + std::vector const& field_indices, + std::vector const& schema, + int num_rows, + cudf::bitmask_type const* input_null_mask, + cudf::size_type input_offset, + field_location const* parent_locs, + bool* row_force_null, + int32_t const* top_row_indices, + int* error_flag, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (num_rows == 0 || field_indices.empty()) { return; } + + bool has_required = false; + std::vector h_is_required(field_indices.size()); + for (size_t i = 0; i < field_indices.size(); ++i) { + h_is_required[i] = schema[field_indices[i]].is_required ? 
1 : 0; + has_required |= (h_is_required[i] != 0); + } + if (!has_required) { return; } + + rmm::device_uvector d_is_required(field_indices.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_is_required.data(), + h_is_required.data(), + h_is_required.size() * sizeof(uint8_t), + cudaMemcpyHostToDevice, + stream.value())); + + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + check_required_fields_kernel<<>>( + locations, + d_is_required.data(), + static_cast(field_indices.size()), + num_rows, + input_null_mask, + input_offset, + parent_locs, + row_force_null, + top_row_indices, + error_flag); +} + +void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_invalid, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (num_items == 0 || row_invalid.size() == 0 || !propagate_to_rows) { return; } + + if (top_row_indices == nullptr) { + CUDF_EXPECTS(static_cast(num_items) <= row_invalid.size(), + "enum invalid-row propagation exceeded row buffer"); + thrust::transform(rmm::exec_policy(stream), + row_invalid.begin(), + row_invalid.begin() + num_items, + item_invalid.begin(), + row_invalid.begin(), + [] __device__(bool row_is_invalid, bool item_is_invalid) { + return row_is_invalid || item_is_invalid; + }); + return; + } + + rmm::device_uvector invalid_rows(num_items, stream, mr); + thrust::transform(rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_items), + invalid_rows.begin(), + [item_invalid = item_invalid.data(), top_row_indices] __device__(int idx) { + return item_invalid[idx] ? 
top_row_indices[idx] : -1; + }); + + auto valid_end = + thrust::remove(rmm::exec_policy(stream), invalid_rows.begin(), invalid_rows.end(), -1); + thrust::sort(rmm::exec_policy(stream), invalid_rows.begin(), valid_end); + auto unique_end = thrust::unique(rmm::exec_policy(stream), invalid_rows.begin(), valid_end); + thrust::for_each(rmm::exec_policy(stream), + invalid_rows.begin(), + unique_end, + [row_invalid = row_invalid.data()] __device__(int32_t row_idx) { + row_invalid[row_idx] = true; + }); +} + +void validate_enum_and_propagate_rows(rmm::device_uvector const& values, + rmm::device_uvector& valid, + std::vector const& valid_enums, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (num_items == 0 || valid_enums.empty()) { return; } + + auto const blocks = static_cast((num_items + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), + valid_enums.data(), + valid_enums.size() * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector item_invalid(num_items, stream, mr); + thrust::fill(rmm::exec_policy(stream), item_invalid.begin(), item_invalid.end(), false); + validate_enum_values_kernel<<>>( + values.data(), + valid.data(), + item_invalid.data(), + d_valid_enums.data(), + static_cast(valid_enums.size()), + num_items); + + propagate_invalid_enum_flags_to_rows( + item_invalid, row_invalid, num_items, top_row_indices, propagate_to_rows, stream, mr); +} + } // namespace spark_rapids_jni::protobuf::detail diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cuh b/src/main/cpp/src/protobuf/protobuf_kernels.cuh index 66869a4fe3..e3cfc98160 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cuh +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cuh @@ -129,14 +129,14 @@ struct 
RepeatedMsgChildLocationProvider { }; template -__global__ void extract_varint_kernel(uint8_t const* message_data, - LocationProvider loc_provider, - int total_items, - OutputType* out, - bool* valid, - int* error_flag, - bool has_default = false, - int64_t default_value = 0) +CUDF_KERNEL void extract_varint_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + OutputType* out, + bool* valid, + int* error_flag, + bool has_default = false, + int64_t default_value = 0) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx >= total_items) return; @@ -181,14 +181,14 @@ __global__ void extract_varint_kernel(uint8_t const* message_data, } template -__global__ void extract_fixed_kernel(uint8_t const* message_data, - LocationProvider loc_provider, - int total_items, - OutputType* out, - bool* valid, - int* error_flag, - bool has_default = false, - OutputType default_value = OutputType{}) +CUDF_KERNEL void extract_fixed_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + OutputType* out, + bool* valid, + int* error_flag, + bool has_default = false, + OutputType default_value = OutputType{}) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx >= total_items) return; @@ -245,15 +245,15 @@ struct batched_scalar_desc { }; template -__global__ void extract_varint_batched_kernel(uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* locations, - int num_loc_fields, - batched_scalar_desc const* descs, - int num_descs, - int num_rows, - int* error_flag) +CUDF_KERNEL void extract_varint_batched_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* locations, + int num_loc_fields, + batched_scalar_desc const* descs, + int num_descs, + int num_rows, + int* error_flag) { int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); 
int fi = static_cast(blockIdx.y); @@ -298,15 +298,15 @@ __global__ void extract_varint_batched_kernel(uint8_t const* message_data, } template -__global__ void extract_fixed_batched_kernel(uint8_t const* message_data, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* locations, - int num_loc_fields, - batched_scalar_desc const* descs, - int num_descs, - int num_rows, - int* error_flag) +CUDF_KERNEL void extract_fixed_batched_kernel(uint8_t const* message_data, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* locations, + int num_loc_fields, + batched_scalar_desc const* descs, + int num_descs, + int num_rows, + int* error_flag) { int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); int fi = static_cast(blockIdx.y); @@ -358,11 +358,11 @@ __global__ void extract_fixed_batched_kernel(uint8_t const* message_data, // ============================================================================ template -__global__ void extract_lengths_kernel(LocationProvider loc_provider, - int total_items, - int32_t* out_lengths, - bool has_default = false, - int32_t default_length = 0) +CUDF_KERNEL void extract_lengths_kernel(LocationProvider loc_provider, + int total_items, + int32_t* out_lengths, + bool has_default = false, + int32_t default_length = 0) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx >= total_items) return; @@ -379,15 +379,15 @@ __global__ void extract_lengths_kernel(LocationProvider loc_provider, } } template -__global__ void copy_varlen_data_kernel(uint8_t const* message_data, - LocationProvider loc_provider, - int total_items, - cudf::size_type const* output_offsets, - char* output_chars, - int* error_flag, - bool has_default = false, - uint8_t const* default_chars = nullptr, - int default_len = 0) +CUDF_KERNEL void copy_varlen_data_kernel(uint8_t const* message_data, + LocationProvider loc_provider, + int total_items, + cudf::size_type const* 
output_offsets, + char* output_chars, + int* error_flag, + bool has_default = false, + uint8_t const* default_chars = nullptr, + int default_len = 0) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx >= total_items) return; @@ -409,165 +409,4 @@ __global__ void copy_varlen_data_kernel(uint8_t const* message_data, } // ============================================================================ -// Forward declarations of non-template __global__ kernels -// ============================================================================ - -__global__ void scan_all_fields_kernel(cudf::column_device_view const d_in, - field_descriptor const* field_descs, - int num_fields, - int const* field_lookup, - int field_lookup_size, - field_location* locations, - int* error_flag, - bool* row_has_invalid_data); - -__global__ void count_repeated_fields_kernel(cudf::column_device_view const d_in, - device_nested_field_descriptor const* schema, - int num_fields, - int depth_level, - repeated_field_info* repeated_info, - int num_repeated_fields, - int const* repeated_field_indices, - field_location* nested_locations, - int num_nested_fields, - int const* nested_field_indices, - int* error_flag, - int const* fn_to_rep_idx = nullptr, - int fn_to_rep_size = 0, - int const* fn_to_nested_idx = nullptr, - int fn_to_nested_size = 0); - -__global__ void scan_all_repeated_occurrences_kernel(cudf::column_device_view const d_in, - repeated_field_scan_desc const* scan_descs, - int num_scan_fields, - int* error_flag, - int const* fn_to_desc_idx = nullptr, - int fn_to_desc_size = 0); - -__global__ void scan_nested_message_fields_kernel(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, - field_location const* parent_locations, - int num_parent_rows, - field_descriptor const* field_descs, - int num_fields, - field_location* output_locations, - int* error_flag); - -__global__ void 
scan_repeated_message_children_kernel(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* msg_row_offsets, - field_location const* msg_locs, - int num_occurrences, - field_descriptor const* child_descs, - int num_child_fields, - field_location* child_locs, - int* error_flag, - int const* child_lookup = nullptr, - int child_lookup_size = 0); - -__global__ void count_repeated_in_nested_kernel(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_rows, - device_nested_field_descriptor const* schema, - int num_fields, - repeated_field_info* repeated_info, - int num_repeated, - int const* repeated_indices, - int* error_flag); - -__global__ void scan_repeated_in_nested_kernel(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_rows, - device_nested_field_descriptor const* schema, - int32_t const* occ_prefix_sums, - int const* repeated_indices, - repeated_occurrence* occurrences, - int* error_flag); - -__global__ void compute_nested_struct_locations_kernel(field_location const* child_locs, - field_location const* msg_locs, - cudf::size_type const* msg_row_offsets, - int child_idx, - int num_child_fields, - field_location* nested_locs, - cudf::size_type* nested_row_offsets, - int total_count, - int* error_flag); - -__global__ void compute_grandchild_parent_locations_kernel(field_location const* parent_locs, - field_location const* child_locs, - int child_idx, - int num_child_fields, - field_location* gc_parent_abs, - int num_rows, - int* error_flag); - -__global__ void compute_virtual_parents_for_nested_repeated_kernel( - repeated_occurrence const* occurrences, - cudf::size_type const* row_list_offsets, - field_location const* parent_locations, - cudf::size_type* 
virtual_row_offsets, - field_location* virtual_parent_locs, - int total_count, - int* error_flag); - -__global__ void compute_msg_locations_from_occurrences_kernel( - repeated_occurrence const* occurrences, - cudf::size_type const* list_offsets, - cudf::size_type base_offset, - field_location* msg_locs, - cudf::size_type* msg_row_offsets, - int total_count, - int* error_flag); - -__global__ void extract_strided_locations_kernel(field_location const* nested_locations, - int field_idx, - int num_fields, - field_location* parent_locs, - int num_rows); - -__global__ void check_required_fields_kernel(field_location const* locations, - uint8_t const* is_required, - int num_fields, - int num_rows, - cudf::bitmask_type const* input_null_mask, - cudf::size_type input_offset, - field_location const* parent_locs, - bool* row_force_null, - int32_t const* top_row_indices, - int* error_flag); - -__global__ void validate_enum_values_kernel(int32_t const* values, - bool* valid, - bool* row_has_invalid_enum, - int32_t const* valid_enum_values, - int num_valid_values, - int num_rows); - -__global__ void compute_enum_string_lengths_kernel(int32_t const* values, - bool const* valid, - int32_t const* valid_enum_values, - int32_t const* enum_name_offsets, - int num_valid_values, - int32_t* lengths, - int num_rows); - -__global__ void copy_enum_string_chars_kernel(int32_t const* values, - bool const* valid, - int32_t const* valid_enum_values, - int32_t const* enum_name_offsets, - uint8_t const* enum_name_chars, - int num_valid_values, - int32_t const* output_offsets, - char* out_chars, - int num_rows); - } // namespace spark_rapids_jni::protobuf::detail From fe73ada21e923b1f1cb727f12369ef74230eb13d Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 23 Mar 2026 12:11:32 +0800 Subject: [PATCH 098/107] address comments Signed-off-by: Haoyang Li --- src/main/cpp/CMakeLists.txt | 2 +- src/main/cpp/benchmarks/protobuf_decode.cu | 26 ++-- src/main/cpp/src/ProtobufJni.cpp | 16 +-- 
src/main/cpp/src/protobuf/protobuf.cu | 46 +++--- src/main/cpp/src/protobuf/protobuf.hpp | 11 +- .../cpp/src/protobuf/protobuf_builders.cu | 131 +++++++++--------- .../src/protobuf/protobuf_device_helpers.cuh | 17 +-- .../src/protobuf/protobuf_host_helpers.hpp | 53 ++++--- src/main/cpp/src/protobuf/protobuf_kernels.cu | 36 ++--- .../cpp/src/protobuf/protobuf_kernels.cuh | 10 +- src/main/cpp/src/protobuf/protobuf_types.cuh | 3 - 11 files changed, 172 insertions(+), 179 deletions(-) diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt index 1ed650d111..cbf86bebbc 100644 --- a/src/main/cpp/CMakeLists.txt +++ b/src/main/cpp/CMakeLists.txt @@ -256,8 +256,8 @@ add_library( src/number_converter.cu src/parse_uri.cu src/protobuf/protobuf.cu - src/protobuf/protobuf_kernels.cu src/protobuf/protobuf_builders.cu + src/protobuf/protobuf_kernels.cu src/regex_rewrite_utils.cu src/row_conversion.cu src/round_float.cu diff --git a/src/main/cpp/benchmarks/protobuf_decode.cu b/src/main/cpp/benchmarks/protobuf_decode.cu index 18b8b93199..4551237e89 100644 --- a/src/main/cpp/benchmarks/protobuf_decode.cu +++ b/src/main/cpp/benchmarks/protobuf_decode.cu @@ -207,9 +207,9 @@ struct FlatScalarCase { int num_int_fields; int num_string_fields; - spark_rapids_jni::ProtobufDecodeContext build_context() const + spark_rapids_jni::protobuf_decode_context build_context() const { - spark_rapids_jni::ProtobufDecodeContext ctx; + spark_rapids_jni::protobuf_decode_context ctx; ctx.fail_on_errors = true; // type_id cycle for integer-like fields @@ -311,9 +311,9 @@ struct FlatScalarCase { struct NestedMessageCase { int num_inner_fields; // scalar fields inside InnerMessage - spark_rapids_jni::ProtobufDecodeContext build_context() const + spark_rapids_jni::protobuf_decode_context build_context() const { - spark_rapids_jni::ProtobufDecodeContext ctx; + spark_rapids_jni::protobuf_decode_context ctx; ctx.fail_on_errors = true; // idx 0: id (int32, top-level) @@ -397,9 +397,9 @@ struct 
RepeatedFieldCase { int avg_labels_per_row; int avg_items_per_row; - spark_rapids_jni::ProtobufDecodeContext build_context() const + spark_rapids_jni::protobuf_decode_context build_context() const { - spark_rapids_jni::ProtobufDecodeContext ctx; + spark_rapids_jni::protobuf_decode_context ctx; ctx.fail_on_errors = true; // idx 0: id (int32, scalar) @@ -508,9 +508,9 @@ struct WideRepeatedMessageCase { int num_child_fields; int avg_items_per_row; - spark_rapids_jni::ProtobufDecodeContext build_context() const + spark_rapids_jni::protobuf_decode_context build_context() const { - spark_rapids_jni::ProtobufDecodeContext ctx; + spark_rapids_jni::protobuf_decode_context ctx; ctx.fail_on_errors = true; // idx 0: id (scalar) @@ -629,9 +629,9 @@ struct RepeatedChildListCase { return (child_idx % 4 == 3); } - spark_rapids_jni::ProtobufDecodeContext build_context() const + spark_rapids_jni::protobuf_decode_context build_context() const { - spark_rapids_jni::ProtobufDecodeContext ctx; + spark_rapids_jni::protobuf_decode_context ctx; ctx.fail_on_errors = true; // idx 0: id (scalar) @@ -782,9 +782,9 @@ struct ManyRepeatedFieldsCase { int num_repeated_int; int num_repeated_str; - spark_rapids_jni::ProtobufDecodeContext build_context() const + spark_rapids_jni::protobuf_decode_context build_context() const { - spark_rapids_jni::ProtobufDecodeContext ctx; + spark_rapids_jni::protobuf_decode_context ctx; ctx.fail_on_errors = true; int fn = 1; @@ -1281,7 +1281,7 @@ static void BM_protobuf_repeated_child_string_build(nvbench::state& state) cudaMemcpyHostToDevice, stream.value())); - spark_rapids_jni::protobuf_detail::NestedRepeatedLocationProvider nr_loc{ + spark_rapids_jni::protobuf_detail::nested_repeated_location_provider nr_loc{ row_offsets, 0, d_parent_locs.data(), c.occs.data()}; auto valid_fn = [] __device__(cudf::size_type) { return true; }; std::vector empty_default; diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index 796ca3239a..61bf1f24e6 
100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -195,14 +195,14 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, } } - spark_rapids_jni::protobuf::ProtobufDecodeContext context{std::move(schema), - std::move(default_int_values), - std::move(default_float_values), - std::move(default_bool_values), - std::move(default_string_values), - std::move(enum_values), - std::move(enum_name_values), - static_cast(fail_on_errors)}; + spark_rapids_jni::protobuf::protobuf_decode_context context{std::move(schema), + std::move(default_int_values), + std::move(default_float_values), + std::move(default_bool_values), + std::move(default_string_values), + std::move(enum_values), + std::move(enum_name_values), + static_cast(fail_on_errors)}; auto result = spark_rapids_jni::protobuf::decode_protobuf_to_struct( *input, context, cudf::get_default_stream(), cudf::get_current_device_resource_ref()); diff --git a/src/main/cpp/src/protobuf/protobuf.cu b/src/main/cpp/src/protobuf/protobuf.cu index 9831a6f71c..6347103376 100644 --- a/src/main/cpp/src/protobuf/protobuf.cu +++ b/src/main/cpp/src/protobuf/protobuf.cu @@ -14,6 +14,7 @@ * limitations under the License. 
*/ +#include "nvtx_ranges.hpp" #include "protobuf/protobuf_host_helpers.hpp" #include @@ -172,7 +173,7 @@ bool is_encoding_compatible(nested_field_descriptor const& field, cudf::data_typ } } -void validate_decode_context(ProtobufDecodeContext const& context) +void validate_decode_context(protobuf_decode_context const& context) { auto const num_fields = context.schema.size(); CUDF_EXPECTS(context.default_ints.size() == num_fields, @@ -276,24 +277,26 @@ void validate_decode_context(ProtobufDecodeContext const& context) } } -ProtobufFieldMetaView make_field_meta_view(ProtobufDecodeContext const& context, int schema_idx) +protobuf_field_meta_view make_field_meta_view(protobuf_decode_context const& context, + int schema_idx) { auto const idx = static_cast(schema_idx); - return ProtobufFieldMetaView{context.schema.at(idx), - cudf::data_type{context.schema.at(idx).output_type}, - context.default_ints.at(idx), - context.default_floats.at(idx), - context.default_bools.at(idx), - context.default_strings.at(idx), - context.enum_valid_values.at(idx), - context.enum_names.at(idx)}; + return protobuf_field_meta_view{context.schema.at(idx), + cudf::data_type{context.schema.at(idx).output_type}, + context.default_ints.at(idx), + context.default_floats.at(idx), + context.default_bools.at(idx), + context.default_strings.at(idx), + context.enum_valid_values.at(idx), + context.enum_names.at(idx)}; } std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, - ProtobufDecodeContext const& context, + protobuf_decode_context const& context, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { + SRJ_FUNC_RANGE(); validate_decode_context(context); auto const& schema = context.schema; auto const& default_ints = context.default_ints; @@ -321,14 +324,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& auto field_type = cudf::data_type{schema[i].output_type}; if (schema[i].is_repeated && field_type.id() == cudf::type_id::STRUCT) { // 
Repeated message field - build empty LIST with proper struct element - rmm::device_uvector offsets(1, stream, mr); - CUDF_CUDA_TRY(cudaMemsetAsync(offsets.data(), 0, sizeof(int32_t), stream.value())); - auto offsets_col = std::make_unique( - cudf::data_type{cudf::type_id::INT32}, 1, offsets.release(), rmm::device_buffer{}, 0); auto empty_struct = make_empty_struct_column_with_schema(schema, i, num_fields, stream, mr); - empty_children.push_back(cudf::make_lists_column( - 0, std::move(offsets_col), std::move(empty_struct), 0, rmm::device_buffer{})); + empty_children.push_back(make_empty_list_column(std::move(empty_struct), stream, mr)); } else if (field_type.id() == cudf::type_id::STRUCT && !schema[i].is_repeated) { // Non-repeated nested message field empty_children.push_back( @@ -726,7 +724,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& auto const dt = field_meta.output_type; auto const enc = static_cast(field_meta.schema.encoding); bool has_def = field_meta.schema.has_default_value; - TopLevelLocationProvider loc_provider{ + top_level_location_provider loc_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; column_map[schema_idx] = extract_typed_column(dt, enc, @@ -769,9 +767,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
field_meta.default_int : 0; - TopLevelLocationProvider loc_provider{ + top_level_location_provider loc_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; - extract_varint_kernel + extract_varint_kernel <<>>(message_data, loc_provider, num_rows, @@ -801,9 +799,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Regular protobuf STRING (length-delimited) bool has_def_str = has_def; auto const& def_str = field_meta.default_string; - TopLevelLocationProvider len_provider{ + top_level_location_provider len_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; - TopLevelLocationProvider copy_provider{ + top_level_location_provider copy_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; auto valid_fn = [locs = d_locations.data(), i, num_scalar, has_def_str] __device__( cudf::size_type row) { @@ -831,9 +829,9 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // bytes (BinaryType) represented as LIST bool has_def_bytes = has_def; auto const& def_bytes = field_meta.default_string; - TopLevelLocationProvider len_provider{ + top_level_location_provider len_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; - TopLevelLocationProvider copy_provider{ + top_level_location_provider copy_provider{ list_offsets, base_offset, d_locations.data(), i, num_scalar}; auto valid_fn = [locs = d_locations.data(), i, num_scalar, has_def_bytes] __device__( cudf::size_type row) { diff --git a/src/main/cpp/src/protobuf/protobuf.hpp b/src/main/cpp/src/protobuf/protobuf.hpp index b86a5960d3..4e496707a1 100644 --- a/src/main/cpp/src/protobuf/protobuf.hpp +++ b/src/main/cpp/src/protobuf/protobuf.hpp @@ -72,7 +72,7 @@ struct nested_field_descriptor { bool has_default_value; // Whether this field has a default value }; -struct ProtobufDecodeContext { +struct protobuf_decode_context { std::vector schema; std::vector default_ints; std::vector default_floats; @@ -83,7 +83,7 @@ 
struct ProtobufDecodeContext { bool fail_on_errors; }; -struct ProtobufFieldMetaView { +struct protobuf_field_meta_view { nested_field_descriptor const& schema; cudf::data_type output_type; int64_t default_int; @@ -96,12 +96,13 @@ struct ProtobufFieldMetaView { bool is_encoding_compatible(nested_field_descriptor const& field, cudf::data_type const& type); -void validate_decode_context(ProtobufDecodeContext const& context); +void validate_decode_context(protobuf_decode_context const& context); -ProtobufFieldMetaView make_field_meta_view(ProtobufDecodeContext const& context, int schema_idx); +protobuf_field_meta_view make_field_meta_view(protobuf_decode_context const& context, + int schema_idx); std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, - ProtobufDecodeContext const& context, + protobuf_decode_context const& context, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr); diff --git a/src/main/cpp/src/protobuf/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu index 0d63e7cddb..776a397b4e 100644 --- a/src/main/cpp/src/protobuf/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -80,13 +80,13 @@ std::unique_ptr build_repeated_msg_child_varlen_column( }); if (total_data > 0) { - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - child_idx, - num_child_fields}; - copy_varlen_data_kernel + repeated_msg_child_location_provider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + child_idx, + num_child_fields}; + copy_varlen_data_kernel <<>>(message_data, loc_provider, total_count, @@ -395,13 +395,13 @@ std::unique_ptr build_repeated_msg_child_enum_string_column( rmm::device_uvector enum_values(total_count, stream, mr); rmm::device_uvector valid((total_count > 0 ? 
total_count : 1), stream, mr); - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - child_idx, - num_child_fields}; - extract_varint_kernel + repeated_msg_child_location_provider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + child_idx, + num_child_fields}; + extract_varint_kernel <<>>(message_data, loc_provider, total_count, @@ -456,7 +456,7 @@ std::unique_ptr build_repeated_enum_string_column( // 1. Extract enum integer values from occurrences rmm::device_uvector enum_ints(total_count, stream, mr); rmm::device_uvector elem_valid(total_count, stream, mr); - RepeatedLocationProvider rep_loc{list_offsets, base_offset, d_occurrences.data()}; + repeated_location_provider rep_loc{list_offsets, base_offset, d_occurrences.data()}; extract_varint_kernel <<>>(message_data, rep_loc, @@ -582,8 +582,8 @@ std::unique_ptr build_repeated_string_column( rmm::device_uvector str_lengths(total_count, stream, mr); auto const threads = THREADS_PER_BLOCK; auto const blocks = static_cast((total_count + threads - 1u) / threads); - RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; - extract_lengths_kernel + repeated_location_provider loc_provider{list_offsets, base_offset, d_occurrences.data()}; + extract_lengths_kernel <<>>(loc_provider, total_count, str_lengths.data()); auto [str_offsets_col, total_chars] = cudf::strings::detail::make_offsets_child_column( @@ -591,8 +591,8 @@ std::unique_ptr build_repeated_string_column( rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { - RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; - copy_varlen_data_kernel + repeated_location_provider loc_provider{list_offsets, base_offset, d_occurrences.data()}; + copy_varlen_data_kernel <<>>(message_data, loc_provider, total_count, @@ -903,12 +903,12 @@ std::unique_ptr build_repeated_struct_column( 
case cudf::type_id::UINT64: case cudf::type_id::FLOAT32: case cudf::type_id::FLOAT64: { - RepeatedMsgChildLocationProvider loc_provider{d_msg_row_offsets.data(), - 0, - d_msg_locs.data(), - d_child_locs.data(), - ci, - num_child_fields}; + repeated_msg_child_location_provider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; struct_children.push_back( extract_typed_column(dt, enc, @@ -1098,7 +1098,7 @@ std::unique_ptr build_nested_struct_column( int depth, bool propagate_invalid_rows) { - CUDF_EXPECTS(depth < MAX_NESTED_STRUCT_DECODE_DEPTH, + CUDF_EXPECTS(depth < MAX_NESTING_DEPTH, "Nested protobuf struct depth exceeds supported decode recursion limit"); if (num_rows == 0) { @@ -1210,12 +1210,12 @@ std::unique_ptr build_nested_struct_column( case cudf::type_id::UINT64: case cudf::type_id::FLOAT32: case cudf::type_id::FLOAT64: { - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; + nested_location_provider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; struct_children.push_back( extract_typed_column(dt, enc, @@ -1245,13 +1245,13 @@ std::unique_ptr build_nested_struct_column( rmm::device_uvector out(num_rows, stream, mr); rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); int64_t def_int = has_def ? 
default_ints[child_schema_idx] : 0; - NestedLocationProvider loc_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - extract_varint_kernel + nested_location_provider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_varint_kernel <<>>(message_data, loc_provider, num_rows, @@ -1287,18 +1287,18 @@ std::unique_ptr build_nested_struct_column( } else { bool has_def_str = has_def; auto const& def_str = default_strings[child_schema_idx]; - NestedLocationProvider len_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - NestedLocationProvider copy_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; + nested_location_provider len_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + nested_location_provider copy_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; auto valid_fn = [plocs = d_parent_locs.data(), flocs = d_child_locations.data(), ci, @@ -1329,18 +1329,18 @@ std::unique_ptr build_nested_struct_column( // bytes (BinaryType) represented as LIST bool has_def_bytes = has_def; auto const& def_bytes = default_strings[child_schema_idx]; - NestedLocationProvider len_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; - NestedLocationProvider copy_provider{list_offsets, - base_offset, - d_parent_locs.data(), - d_child_locations.data(), - ci, - num_child_fields}; + nested_location_provider len_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + nested_location_provider copy_provider{list_offsets, + base_offset, + d_parent_locs.data(), + 
d_child_locations.data(), + ci, + num_child_fields}; auto valid_fn = [plocs = d_parent_locs.data(), flocs = d_child_locations.data(), ci, @@ -1555,7 +1555,8 @@ std::unique_ptr build_repeated_child_list_column( std::unique_ptr child_values; auto const rep_blocks = static_cast((total_rep_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); - NestedRepeatedLocationProvider nr_loc{row_offsets, base_offset, parent_locs, d_rep_occs.data()}; + nested_repeated_location_provider nr_loc{ + row_offsets, base_offset, parent_locs, d_rep_occs.data()}; if (elem_type_id == cudf::type_id::BOOL8 || elem_type_id == cudf::type_id::INT32 || elem_type_id == cudf::type_id::UINT32 || elem_type_id == cudf::type_id::INT64 || @@ -1593,7 +1594,7 @@ std::unique_ptr build_repeated_child_list_column( enum_valid_values[child_schema_idx], enum_names[child_schema_idx], stream, mr); rmm::device_uvector enum_values(total_rep_count, stream, mr); rmm::device_uvector valid((total_rep_count > 0 ? total_rep_count : 1), stream, mr); - extract_varint_kernel + extract_varint_kernel <<>>(message_data, nr_loc, total_rep_count, diff --git a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh index 9e01176eac..7304653e59 100644 --- a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh +++ b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh @@ -77,18 +77,17 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con return -1; // Invalid varint } case wire_type_value(proto_wire_type::I64BIT): - // Check if there's enough data for 8 bytes - if (end - cur < 8) return -1; + if (end - cur < 8) { return -1; } return 8; case wire_type_value(proto_wire_type::I32BIT): - // Check if there's enough data for 4 bytes - if (end - cur < 4) return -1; + if (end - cur < 4) { return -1; } return 4; case wire_type_value(proto_wire_type::LEN): { uint64_t len; int n; if (!read_varint(cur, end, len, n)) return -1; - if (len > static_cast(end - 
cur - n) || len > static_cast(INT_MAX - n)) + if (len > static_cast(end - cur - n) || + len > static_cast(cuda::std::numeric_limits::max() - n)) return -1; return n + static_cast(len); } @@ -122,7 +121,9 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con uint64_t len; int len_bytes; if (!read_varint(cur, end, len, len_bytes)) return -1; - if (len > static_cast(INT_MAX - len_bytes)) return -1; + if (len > static_cast(cuda::std::numeric_limits::max() - len_bytes)) { + return -1; + } inner_size = len_bytes + static_cast(len); break; } @@ -172,7 +173,7 @@ __device__ inline bool get_field_data_location( int len_bytes; if (!read_varint(cur, end, len, len_bytes)) return false; if (len > static_cast(end - cur - len_bytes) || - len > static_cast(INT_MAX)) { + len > static_cast(cuda::std::numeric_limits::max())) { return false; } data_offset = len_bytes; // offset past the length prefix @@ -187,7 +188,7 @@ __device__ inline bool get_field_data_location( return true; } -__device__ __host__ inline size_t flat_index(size_t row, size_t width, size_t col) +CUDF_HOST_DEVICE inline size_t flat_index(size_t row, size_t width, size_t col) { return row * width + col; } diff --git a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp index 326b07be13..8a3017124b 100644 --- a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp +++ b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp @@ -55,42 +55,37 @@ namespace spark_rapids_jni::protobuf::detail { // ============================================================================ /** - * Build a host-side direct-mapped lookup table: field_number -> local_index, - * given an array of schema indices and the schema itself. + * Build a host-side direct-mapped lookup table: field_number -> index. * Returns an empty vector if the max field number exceeds the threshold. 
+ * + * @tparam GetFieldNumber callable (int i) -> int returning the field number for index i */ -inline std::vector build_index_lookup_table(nested_field_descriptor const* schema, - int const* field_indices, - int num_indices) +template +inline std::vector build_lookup_table(int num_entries, GetFieldNumber get_fn) { int max_fn = 0; - for (int i = 0; i < num_indices; i++) { - max_fn = std::max(max_fn, schema[field_indices[i]].field_number); + for (int i = 0; i < num_entries; i++) { + max_fn = std::max(max_fn, get_fn(i)); } - if (max_fn > FIELD_LOOKUP_TABLE_MAX) return {}; + if (max_fn > FIELD_LOOKUP_TABLE_MAX) { return {}; } std::vector table(max_fn + 1, -1); - for (int i = 0; i < num_indices; i++) { - table[schema[field_indices[i]].field_number] = i; + for (int i = 0; i < num_entries; i++) { + table[get_fn(i)] = i; } return table; } -/** - * Build a host-side direct-mapped lookup table: field_number -> field_index. - * Returns an empty vector if the max field number exceeds the threshold. 
- */ +inline std::vector build_index_lookup_table(nested_field_descriptor const* schema, + int const* field_indices, + int num_indices) +{ + return build_lookup_table(num_indices, + [&](int i) { return schema[field_indices[i]].field_number; }); +} + inline std::vector build_field_lookup_table(field_descriptor const* descs, int num_fields) { - int max_fn = 0; - for (int i = 0; i < num_fields; i++) { - max_fn = std::max(max_fn, descs[i].field_number); - } - if (max_fn > FIELD_LOOKUP_TABLE_MAX) return {}; - std::vector table(max_fn + 1, -1); - for (int i = 0; i < num_fields; i++) { - table[descs[i].field_number] = i; - } - return table; + return build_lookup_table(num_fields, [&](int i) { return descs[i].field_number; }); } template @@ -922,23 +917,23 @@ inline std::unique_ptr build_repeated_scalar_column( constexpr bool is_floating_point = std::is_same_v || std::is_same_v; bool use_fixed_kernel = is_floating_point || (encoding == encoding_value(proto_encoding::FIXED)); - RepeatedLocationProvider loc_provider{list_offsets, base_offset, d_occurrences.data()}; + repeated_location_provider loc_provider{list_offsets, base_offset, d_occurrences.data()}; if (use_fixed_kernel) { if constexpr (sizeof(T) == 4) { - extract_fixed_kernel + extract_fixed_kernel <<>>( message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } else { - extract_fixed_kernel + extract_fixed_kernel <<>>( message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } } else if (zigzag) { - extract_varint_kernel + extract_varint_kernel <<>>( message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } else { - extract_varint_kernel + extract_varint_kernel <<>>( message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); } diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cu b/src/main/cpp/src/protobuf/protobuf_kernels.cu index 174420dd7b..86493c95c3 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cu 
+++ b/src/main/cpp/src/protobuf/protobuf_kernels.cu @@ -17,7 +17,9 @@ #include "protobuf/protobuf_kernels.cuh" #include +#include #include +#include #include #include @@ -125,7 +127,7 @@ CUDF_KERNEL void scan_all_fields_kernel( return; } if (len > static_cast(msg_end - cur - len_bytes) || - len > static_cast(INT_MAX)) { + len > static_cast(cuda::std::numeric_limits::max())) { set_error_once(error_flag, ERR_OVERFLOW); mark_row_error(); return; @@ -457,7 +459,7 @@ CUDF_KERNEL void count_repeated_fields_kernel(cudf::column_device_view const d_i return false; } if (len > static_cast(msg_end - cur - len_bytes) || - len > static_cast(INT_MAX)) { + len > static_cast(cuda::std::numeric_limits::max())) { set_error_once(error_flag, ERR_OVERFLOW); return false; } @@ -672,7 +674,7 @@ CUDF_KERNEL void scan_nested_message_fields_kernel(uint8_t const* message_data, return; } if (len > static_cast(nested_end - cur - len_bytes) || - len > static_cast(INT_MAX)) { + len > static_cast(cuda::std::numeric_limits::max())) { set_error_once(error_flag, ERR_OVERFLOW); return; } @@ -775,7 +777,7 @@ CUDF_KERNEL void scan_repeated_message_children_kernel( return; } if (len > static_cast(msg_end - cur - len_bytes) || - len > static_cast(INT_MAX)) { + len > static_cast(cuda::std::numeric_limits::max())) { set_error_once(error_flag, ERR_OVERFLOW); return; } @@ -1691,19 +1693,16 @@ void maybe_check_required_fields(field_location const* locations, if (num_rows == 0 || field_indices.empty()) { return; } bool has_required = false; - std::vector h_is_required(field_indices.size()); + auto h_is_required = + cudf::detail::make_host_vector(field_indices.size(), cudf::get_default_stream()); for (size_t i = 0; i < field_indices.size(); ++i) { h_is_required[i] = schema[field_indices[i]].is_required ? 
1 : 0; has_required |= (h_is_required[i] != 0); } if (!has_required) { return; } - rmm::device_uvector d_is_required(field_indices.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_is_required.data(), - h_is_required.data(), - h_is_required.size() * sizeof(uint8_t), - cudaMemcpyHostToDevice, - stream.value())); + auto d_is_required = cudf::detail::make_device_uvector_async( + h_is_required, stream, rmm::mr::get_current_device_resource()); auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); check_required_fields_kernel<<>>( @@ -1732,7 +1731,7 @@ void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_ if (top_row_indices == nullptr) { CUDF_EXPECTS(static_cast(num_items) <= row_invalid.size(), "enum invalid-row propagation exceeded row buffer"); - thrust::transform(rmm::exec_policy(stream), + thrust::transform(rmm::exec_policy_nosync(stream), row_invalid.begin(), row_invalid.begin() + num_items, item_invalid.begin(), @@ -1744,7 +1743,7 @@ void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_ } rmm::device_uvector invalid_rows(num_items, stream, mr); - thrust::transform(rmm::exec_policy(stream), + thrust::transform(rmm::exec_policy_nosync(stream), thrust::make_counting_iterator(0), thrust::make_counting_iterator(num_items), invalid_rows.begin(), @@ -1753,10 +1752,11 @@ void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_ }); auto valid_end = - thrust::remove(rmm::exec_policy(stream), invalid_rows.begin(), invalid_rows.end(), -1); - thrust::sort(rmm::exec_policy(stream), invalid_rows.begin(), valid_end); - auto unique_end = thrust::unique(rmm::exec_policy(stream), invalid_rows.begin(), valid_end); - thrust::for_each(rmm::exec_policy(stream), + thrust::remove(rmm::exec_policy_nosync(stream), invalid_rows.begin(), invalid_rows.end(), -1); + thrust::sort(rmm::exec_policy_nosync(stream), invalid_rows.begin(), valid_end); + auto unique_end = + 
thrust::unique(rmm::exec_policy_nosync(stream), invalid_rows.begin(), valid_end); + thrust::for_each(rmm::exec_policy_nosync(stream), invalid_rows.begin(), unique_end, [row_invalid = row_invalid.data()] __device__(int32_t row_idx) { @@ -1785,7 +1785,7 @@ void validate_enum_and_propagate_rows(rmm::device_uvector const& values stream.value())); rmm::device_uvector item_invalid(num_items, stream, mr); - thrust::fill(rmm::exec_policy(stream), item_invalid.begin(), item_invalid.end(), false); + thrust::fill(rmm::exec_policy_nosync(stream), item_invalid.begin(), item_invalid.end(), false); validate_enum_values_kernel<<>>( values.data(), valid.data(), diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cuh b/src/main/cpp/src/protobuf/protobuf_kernels.cuh index e3cfc98160..d29ba09648 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cuh +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cuh @@ -33,7 +33,7 @@ namespace spark_rapids_jni::protobuf::detail { // Data Extraction Location Providers // ============================================================================ -struct TopLevelLocationProvider { +struct top_level_location_provider { cudf::size_type const* offsets; cudf::size_type base_offset; field_location const* locations; @@ -50,7 +50,7 @@ struct TopLevelLocationProvider { } }; -struct RepeatedLocationProvider { +struct repeated_location_provider { cudf::size_type const* row_offsets; cudf::size_type base_offset; repeated_occurrence const* occurrences; @@ -63,7 +63,7 @@ struct RepeatedLocationProvider { } }; -struct NestedLocationProvider { +struct nested_location_provider { cudf::size_type const* row_offsets; cudf::size_type base_offset; field_location const* parent_locations; @@ -86,7 +86,7 @@ struct NestedLocationProvider { } }; -struct NestedRepeatedLocationProvider { +struct nested_repeated_location_provider { cudf::size_type const* row_offsets; cudf::size_type base_offset; field_location const* parent_locations; @@ -105,7 +105,7 @@ struct 
NestedRepeatedLocationProvider { } }; -struct RepeatedMsgChildLocationProvider { +struct repeated_msg_child_location_provider { cudf::size_type const* row_offsets; cudf::size_type base_offset; field_location const* msg_locations; diff --git a/src/main/cpp/src/protobuf/protobuf_types.cuh b/src/main/cpp/src/protobuf/protobuf_types.cuh index e8ffa4a292..d9c5543b46 100644 --- a/src/main/cpp/src/protobuf/protobuf_types.cuh +++ b/src/main/cpp/src/protobuf/protobuf_types.cuh @@ -40,9 +40,6 @@ constexpr int ERR_SCHEMA_TOO_LARGE = 10; constexpr int ERR_MISSING_ENUM_META = 11; constexpr int ERR_REPEATED_COUNT_MISMATCH = 12; -// Maximum supported nesting depth for recursive struct decoding. -constexpr int MAX_NESTED_STRUCT_DECODE_DEPTH = 10; - // Threshold for using a direct-mapped lookup table for field_number -> field_index. // Field numbers above this threshold fall back to linear search. constexpr int FIELD_LOOKUP_TABLE_MAX = 4096; From d01a25b3d095e25361041a722129dbb509a7ed32 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 23 Mar 2026 14:38:44 +0800 Subject: [PATCH 099/107] copyright Signed-off-by: Haoyang Li --- src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index b013a65bee..132efa3e3e 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2025-2026, NVIDIA CORPORATION. + * Copyright (c) 2026, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
From 54223ca2e8ac0a63f64ed7f1487fde5c327dece9 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 24 Mar 2026 14:38:38 +0800 Subject: [PATCH 100/107] fix Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 136 +++++++++--------- src/main/cpp/src/protobuf/protobuf.cu | 2 +- .../cpp/src/protobuf/protobuf_builders.cu | 2 +- ..._helpers.hpp => protobuf_host_helpers.cuh} | 10 +- src/main/cpp/src/protobuf/protobuf_kernels.cu | 18 +-- 5 files changed, 81 insertions(+), 87 deletions(-) rename src/main/cpp/src/protobuf/{protobuf_host_helpers.hpp => protobuf_host_helpers.cuh} (99%) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index 61bf1f24e6..b7a2444e9a 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -19,6 +19,63 @@ #include +namespace { + +std::vector> jni_byte_array_of_arrays_to_vectors(JNIEnv* env, + jobjectArray arr, + int num_fields) +{ + std::vector> result; + result.reserve(num_fields); + for (int i = 0; i < num_fields; ++i) { + jbyteArray byte_arr = static_cast(env->GetObjectArrayElement(arr, i)); + if (env->ExceptionCheck()) { return {}; } + if (byte_arr == nullptr) { + result.emplace_back(); + } else { + jsize len = env->GetArrayLength(byte_arr); + jbyte* bytes = env->GetByteArrayElements(byte_arr, nullptr); + if (bytes == nullptr) { + env->DeleteLocalRef(byte_arr); + return {}; + } + result.emplace_back(reinterpret_cast(bytes), + reinterpret_cast(bytes) + len); + env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT); + env->DeleteLocalRef(byte_arr); + } + } + return result; +} + +std::vector> jni_int_array_of_arrays_to_vectors(JNIEnv* env, + jobjectArray arr, + int num_fields) +{ + std::vector> result; + result.reserve(num_fields); + for (int i = 0; i < num_fields; ++i) { + jintArray int_arr = static_cast(env->GetObjectArrayElement(arr, i)); + if (env->ExceptionCheck()) { return {}; } + if (int_arr == nullptr) { + result.emplace_back(); + } else { + jsize len = 
env->GetArrayLength(int_arr); + jint* ints = env->GetIntArrayElements(int_arr, nullptr); + if (ints == nullptr) { + env->DeleteLocalRef(int_arr); + return {}; + } + result.emplace_back(ints, ints + len); + env->ReleaseIntArrayElements(int_arr, ints, JNI_ABORT); + env->DeleteLocalRef(int_arr); + } + } + return result; +} + +} // namespace + extern "C" { JNIEXPORT jlong JNICALL @@ -111,52 +168,13 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, std::vector default_int_values(n_default_ints.begin(), n_default_ints.end()); std::vector default_float_values(n_default_floats.begin(), n_default_floats.end()); - // Convert default string values - std::vector> default_string_values; - default_string_values.reserve(num_fields); - for (int i = 0; i < num_fields; ++i) { - jbyteArray byte_arr = static_cast(env->GetObjectArrayElement(default_strings, i)); - if (env->ExceptionCheck()) { return 0; } - if (byte_arr == nullptr) { - default_string_values.emplace_back(); - } else { - jsize len = env->GetArrayLength(byte_arr); - jbyte* bytes = env->GetByteArrayElements(byte_arr, nullptr); - if (bytes == nullptr) { - env->DeleteLocalRef(byte_arr); - return 0; - } - default_string_values.emplace_back(reinterpret_cast(bytes), - reinterpret_cast(bytes) + len); - env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT); - env->DeleteLocalRef(byte_arr); - } - } + auto default_string_values = + jni_byte_array_of_arrays_to_vectors(env, default_strings, num_fields); + if (env->ExceptionCheck()) { return 0; } - // Convert enum valid values - std::vector> enum_values; - enum_values.reserve(num_fields); - for (int i = 0; i < num_fields; ++i) { - jintArray int_arr = static_cast(env->GetObjectArrayElement(enum_valid_values, i)); - if (env->ExceptionCheck()) { return 0; } - if (int_arr == nullptr) { - enum_values.emplace_back(); - } else { - jsize len = env->GetArrayLength(int_arr); - jint* ints = env->GetIntArrayElements(int_arr, nullptr); - if (ints == nullptr) { - 
env->DeleteLocalRef(int_arr); - return 0; - } - enum_values.emplace_back(ints, ints + len); - env->ReleaseIntArrayElements(int_arr, ints, JNI_ABORT); - env->DeleteLocalRef(int_arr); - } - } + auto enum_values = jni_int_array_of_arrays_to_vectors(env, enum_valid_values, num_fields); + if (env->ExceptionCheck()) { return 0; } - // Convert enum names (byte[][][]). For each field: - // - null => not an enum-as-string field - // - byte[][] where each byte[] is UTF-8 enum name, ordered with enum_values[field] std::vector>> enum_name_values; enum_name_values.reserve(num_fields); for (int i = 0; i < num_fields; ++i) { @@ -165,33 +183,11 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, if (names_arr == nullptr) { enum_name_values.emplace_back(); } else { - jsize num_names = env->GetArrayLength(names_arr); - std::vector> names_for_field; - names_for_field.reserve(num_names); - for (jsize j = 0; j < num_names; ++j) { - jbyteArray name_bytes = static_cast(env->GetObjectArrayElement(names_arr, j)); - if (env->ExceptionCheck()) { - env->DeleteLocalRef(names_arr); - return 0; - } - if (name_bytes == nullptr) { - names_for_field.emplace_back(); - } else { - jsize len = env->GetArrayLength(name_bytes); - jbyte* bytes = env->GetByteArrayElements(name_bytes, nullptr); - if (bytes == nullptr) { - env->DeleteLocalRef(name_bytes); - env->DeleteLocalRef(names_arr); - return 0; - } - names_for_field.emplace_back(reinterpret_cast(bytes), - reinterpret_cast(bytes) + len); - env->ReleaseByteArrayElements(name_bytes, bytes, JNI_ABORT); - env->DeleteLocalRef(name_bytes); - } - } - enum_name_values.push_back(std::move(names_for_field)); + jsize num_names = env->GetArrayLength(names_arr); + auto names_for_field = jni_byte_array_of_arrays_to_vectors(env, names_arr, num_names); env->DeleteLocalRef(names_arr); + if (env->ExceptionCheck()) { return 0; } + enum_name_values.push_back(std::move(names_for_field)); } } diff --git a/src/main/cpp/src/protobuf/protobuf.cu 
b/src/main/cpp/src/protobuf/protobuf.cu index 6347103376..7bb7385ce8 100644 --- a/src/main/cpp/src/protobuf/protobuf.cu +++ b/src/main/cpp/src/protobuf/protobuf.cu @@ -15,7 +15,7 @@ */ #include "nvtx_ranges.hpp" -#include "protobuf/protobuf_host_helpers.hpp" +#include "protobuf/protobuf_host_helpers.cuh" #include diff --git a/src/main/cpp/src/protobuf/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu index 776a397b4e..a3eeac6c1f 100644 --- a/src/main/cpp/src/protobuf/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "protobuf/protobuf_host_helpers.hpp" +#include "protobuf/protobuf_host_helpers.cuh" #include #include diff --git a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp b/src/main/cpp/src/protobuf/protobuf_host_helpers.cuh similarity index 99% rename from src/main/cpp/src/protobuf/protobuf_host_helpers.hpp rename to src/main/cpp/src/protobuf/protobuf_host_helpers.cuh index 8a3017124b..b794ca0121 100644 --- a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp +++ b/src/main/cpp/src/protobuf/protobuf_host_helpers.cuh @@ -20,6 +20,7 @@ #include #include +#include #include #include #include @@ -34,7 +35,6 @@ #include #include #include -#include #include #include #include @@ -635,10 +635,12 @@ inline std::unique_ptr extract_and_build_string_or_bytes_column( rmm::device_async_resource_ref mr) { int32_t def_len = has_default ? 
static_cast(default_bytes.size()) : 0; - rmm::device_uvector d_default(def_len, stream, mr); + rmm::device_uvector d_default(0, stream, mr); if (has_default && def_len > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync( - d_default.data(), default_bytes.data(), def_len, cudaMemcpyHostToDevice, stream.value())); + auto h_default = cudf::detail::make_host_vector(def_len, stream); + std::copy(default_bytes.begin(), default_bytes.end(), h_default.begin()); + d_default = cudf::detail::make_device_uvector_async( + h_default, stream, rmm::mr::get_current_device_resource()); } rmm::device_uvector lengths(num_rows, stream, mr); diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cu b/src/main/cpp/src/protobuf/protobuf_kernels.cu index 86493c95c3..53e42e8784 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cu @@ -19,7 +19,6 @@ #include #include #include -#include #include #include @@ -1692,9 +1691,8 @@ void maybe_check_required_fields(field_location const* locations, { if (num_rows == 0 || field_indices.empty()) { return; } - bool has_required = false; - auto h_is_required = - cudf::detail::make_host_vector(field_indices.size(), cudf::get_default_stream()); + bool has_required = false; + auto h_is_required = cudf::detail::make_host_vector(field_indices.size(), stream); for (size_t i = 0; i < field_indices.size(); ++i) { h_is_required[i] = schema[field_indices[i]].is_required ? 
1 : 0; has_required |= (h_is_required[i] != 0); @@ -1776,13 +1774,11 @@ void validate_enum_and_propagate_rows(rmm::device_uvector const& values { if (num_items == 0 || valid_enums.empty()) { return; } - auto const blocks = static_cast((num_items + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + auto const blocks = static_cast((num_items + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + auto h_valid_enums = cudf::detail::make_host_vector(valid_enums.size(), stream); + std::copy(valid_enums.begin(), valid_enums.end(), h_valid_enums.begin()); + auto d_valid_enums = cudf::detail::make_device_uvector_async( + h_valid_enums, stream, rmm::mr::get_current_device_resource()); rmm::device_uvector item_invalid(num_items, stream, mr); thrust::fill(rmm::exec_policy_nosync(stream), item_invalid.begin(), item_invalid.end(), false); From 70a214b0e25df8856c0a6aa250e9c7615d089d3a Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 30 Mar 2026 10:32:38 +0800 Subject: [PATCH 101/107] apply comments suggestions Signed-off-by: Haoyang Li --- src/main/cpp/src/ProtobufJni.cpp | 130 ++++++++++-------- src/main/cpp/src/protobuf/protobuf.cu | 37 ++--- src/main/cpp/src/protobuf/protobuf.hpp | 16 +-- .../cpp/src/protobuf/protobuf_builders.cu | 96 ++++++------- .../src/protobuf/protobuf_device_helpers.cuh | 27 ++-- .../src/protobuf/protobuf_host_helpers.cuh | 69 ++++++---- src/main/cpp/src/protobuf/protobuf_kernels.cu | 86 ++++++------ .../cpp/src/protobuf/protobuf_kernels.cuh | 30 ++-- src/main/cpp/src/protobuf/protobuf_types.cuh | 4 +- .../jni/ProtobufSchemaDescriptorTest.java | 5 +- thirdparty/cudf | 2 +- 11 files changed, 258 insertions(+), 244 deletions(-) diff --git a/src/main/cpp/src/ProtobufJni.cpp b/src/main/cpp/src/ProtobufJni.cpp index 
b7a2444e9a..c5eb07f4b6 100644 --- a/src/main/cpp/src/ProtobufJni.cpp +++ b/src/main/cpp/src/ProtobufJni.cpp @@ -17,61 +17,70 @@ #include "cudf_jni_apis.hpp" #include "protobuf/protobuf.hpp" +#include #include namespace { -std::vector> jni_byte_array_of_arrays_to_vectors(JNIEnv* env, - jobjectArray arr, - int num_fields) +/** + * Convert a Java Object[] of primitive arrays into a vector-of-vectors. + * @tparam CppT Element type in the output vectors (e.g. host_vector, + * host_vector). + * @param convert Per-element callback: (JNIEnv*, jobject) -> std::vector. + * Must return an empty vector on null input. Returns std::nullopt on JNI error. + */ +template +std::vector jni_array_of_arrays_to_vectors(JNIEnv* env, + jobjectArray arr, + int num_elements, + ConvertFn convert) { - std::vector> result; - result.reserve(num_fields); - for (int i = 0; i < num_fields; ++i) { - jbyteArray byte_arr = static_cast(env->GetObjectArrayElement(arr, i)); + std::vector result; + result.reserve(num_elements); + for (int i = 0; i < num_elements; ++i) { + jobject elem = env->GetObjectArrayElement(arr, i); if (env->ExceptionCheck()) { return {}; } - if (byte_arr == nullptr) { - result.emplace_back(); - } else { - jsize len = env->GetArrayLength(byte_arr); - jbyte* bytes = env->GetByteArrayElements(byte_arr, nullptr); - if (bytes == nullptr) { - env->DeleteLocalRef(byte_arr); - return {}; - } - result.emplace_back(reinterpret_cast(bytes), - reinterpret_cast(bytes) + len); - env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT); - env->DeleteLocalRef(byte_arr); - } + auto vec = convert(env, elem); + if (elem != nullptr) { env->DeleteLocalRef(elem); } + if (env->ExceptionCheck()) { return {}; } + result.push_back(std::move(vec)); } return result; } -std::vector> jni_int_array_of_arrays_to_vectors(JNIEnv* env, - jobjectArray arr, - int num_fields) +cudf::detail::host_vector jni_byte_array_to_vector(JNIEnv* env, jobject obj) { - std::vector> result; - result.reserve(num_fields); - for (int 
i = 0; i < num_fields; ++i) { - jintArray int_arr = static_cast(env->GetObjectArrayElement(arr, i)); - if (env->ExceptionCheck()) { return {}; } - if (int_arr == nullptr) { - result.emplace_back(); - } else { - jsize len = env->GetArrayLength(int_arr); - jint* ints = env->GetIntArrayElements(int_arr, nullptr); - if (ints == nullptr) { - env->DeleteLocalRef(int_arr); - return {}; - } - result.emplace_back(ints, ints + len); - env->ReleaseIntArrayElements(int_arr, ints, JNI_ABORT); - env->DeleteLocalRef(int_arr); - } + if (obj == nullptr) { + return cudf::detail::make_host_vector(0, cudf::get_default_stream()); } - return result; + auto byte_arr = static_cast(obj); + jsize len = env->GetArrayLength(byte_arr); + jbyte* bytes = env->GetByteArrayElements(byte_arr, nullptr); + if (bytes == nullptr) { + return cudf::detail::make_host_vector(0, cudf::get_default_stream()); + } + auto vec = cudf::detail::make_host_vector(len, cudf::get_default_stream()); + std::copy( + reinterpret_cast(bytes), reinterpret_cast(bytes) + len, vec.begin()); + env->ReleaseByteArrayElements(byte_arr, bytes, JNI_ABORT); + return vec; +} + +cudf::detail::host_vector jni_int_array_to_vector(JNIEnv* env, jobject obj) +{ + if (obj == nullptr) { + return cudf::detail::make_host_vector(0, cudf::get_default_stream()); + } + auto int_arr = static_cast(obj); + jsize len = env->GetArrayLength(int_arr); + jint* ints = env->GetIntArrayElements(int_arr, nullptr); + if (ints == nullptr) { + return cudf::detail::make_host_vector(0, cudf::get_default_stream()); + } + auto vec = cudf::detail::make_host_vector(len, cudf::get_default_stream()); + std::copy(ints, ints + len, vec.begin()); + env->ReleaseIntArrayElements(int_arr, ints, JNI_ABORT); + return vec; } } // namespace @@ -168,28 +177,27 @@ Java_com_nvidia_spark_rapids_jni_Protobuf_decodeToStruct(JNIEnv* env, std::vector default_int_values(n_default_ints.begin(), n_default_ints.end()); std::vector default_float_values(n_default_floats.begin(), 
n_default_floats.end()); - auto default_string_values = - jni_byte_array_of_arrays_to_vectors(env, default_strings, num_fields); + auto default_string_values = jni_array_of_arrays_to_vectors>( + env, default_strings, num_fields, jni_byte_array_to_vector); if (env->ExceptionCheck()) { return 0; } - auto enum_values = jni_int_array_of_arrays_to_vectors(env, enum_valid_values, num_fields); + auto enum_values = jni_array_of_arrays_to_vectors>( + env, enum_valid_values, num_fields, jni_int_array_to_vector); if (env->ExceptionCheck()) { return 0; } - std::vector>> enum_name_values; - enum_name_values.reserve(num_fields); - for (int i = 0; i < num_fields; ++i) { - jobjectArray names_arr = static_cast(env->GetObjectArrayElement(enum_names, i)); - if (env->ExceptionCheck()) { return 0; } - if (names_arr == nullptr) { - enum_name_values.emplace_back(); - } else { - jsize num_names = env->GetArrayLength(names_arr); - auto names_for_field = jni_byte_array_of_arrays_to_vectors(env, names_arr, num_names); - env->DeleteLocalRef(names_arr); - if (env->ExceptionCheck()) { return 0; } - enum_name_values.push_back(std::move(names_for_field)); - } - } + auto enum_name_values = + jni_array_of_arrays_to_vectors>>( + env, + enum_names, + num_fields, + [](JNIEnv* e, jobject obj) -> std::vector> { + if (obj == nullptr) { return {}; } + auto inner_arr = static_cast(obj); + jsize num = e->GetArrayLength(inner_arr); + return jni_array_of_arrays_to_vectors>( + e, inner_arr, num, jni_byte_array_to_vector); + }); + if (env->ExceptionCheck()) { return 0; } spark_rapids_jni::protobuf::protobuf_decode_context context{std::move(schema), std::move(default_int_values), diff --git a/src/main/cpp/src/protobuf/protobuf.cu b/src/main/cpp/src/protobuf/protobuf.cu index 7bb7385ce8..e618201dca 100644 --- a/src/main/cpp/src/protobuf/protobuf.cu +++ b/src/main/cpp/src/protobuf/protobuf.cu @@ -580,7 +580,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& bool zz = (enc == 
proto_encoding::ZIGZAG); // STRING, LIST, and enum-as-string go to per-field path - if (tid == cudf::type_id::STRING || tid == cudf::type_id::LIST) continue; + if (tid == cudf::type_id::STRING || tid == cudf::type_id::LIST) { continue; } bool is_fixed = (enc == proto_encoding::FIXED); @@ -592,39 +592,40 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } int g = GRP_FALLBACK; - if (tid == cudf::type_id::INT32 && is_fixed) + if (tid == cudf::type_id::INT32 && is_fixed) { g = 9; - else if (tid == cudf::type_id::INT64 && is_fixed) + } else if (tid == cudf::type_id::INT64 && is_fixed) { g = 10; - else if (tid == cudf::type_id::UINT32 && is_fixed) + } else if (tid == cudf::type_id::UINT32 && is_fixed) { g = 9; - else if (tid == cudf::type_id::UINT64 && is_fixed) + } else if (tid == cudf::type_id::UINT64 && is_fixed) { g = 10; - else if (tid == cudf::type_id::INT32 && !zz) + } else if (tid == cudf::type_id::INT32 && !zz) { g = 0; - else if (tid == cudf::type_id::UINT32) + } else if (tid == cudf::type_id::UINT32) { g = 1; - else if (tid == cudf::type_id::INT64 && !zz) + } else if (tid == cudf::type_id::INT64 && !zz) { g = 2; - else if (tid == cudf::type_id::UINT64) + } else if (tid == cudf::type_id::UINT64) { g = 3; - else if (tid == cudf::type_id::BOOL8) + } else if (tid == cudf::type_id::BOOL8) { g = 4; - else if (tid == cudf::type_id::INT32 && zz) + } else if (tid == cudf::type_id::INT32 && zz) { g = 5; - else if (tid == cudf::type_id::INT64 && zz) + } else if (tid == cudf::type_id::INT64 && zz) { g = 6; - else if (tid == cudf::type_id::FLOAT32) + } else if (tid == cudf::type_id::FLOAT32) { g = 7; - else if (tid == cudf::type_id::FLOAT64) + } else if (tid == cudf::type_id::FLOAT64) { g = 8; + } group_lists[g].push_back(i); } // Helper: batch-extract one group using a 2D kernel, then build columns. 
auto do_batch = [&](std::vector const& idxs, auto kernel_launcher) { int nf = static_cast(idxs.size()); - if (nf == 0) return; + if (nf == 0) { return; } std::vector> bufs; bufs.reserve(nf); @@ -753,7 +754,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& int schema_idx = scalar_field_indices[i]; auto const field_meta = make_field_meta_view(context, schema_idx); auto const dt = field_meta.output_type; - if (dt.id() != cudf::type_id::STRING && dt.id() != cudf::type_id::LIST) continue; + if (dt.id() != cudf::type_id::STRING && dt.id() != cudf::type_id::LIST) { continue; } auto const enc = field_meta.schema.encoding; bool has_def = field_meta.schema.has_default_value; @@ -988,7 +989,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& auto& d_occurrences = *w.occurrences; // Build the appropriate column type based on element type - auto child_type_id = h_device_schema[schema_idx].output_type; + auto child_type_id = static_cast(h_device_schema[schema_idx].output_type_id); // The output_type in schema is the LIST type, but we need element type // For repeated int32, output_type should indicate the element is INT32 @@ -1209,7 +1210,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Build appropriate empty child column std::unique_ptr child_col; - auto child_type_id = h_device_schema[schema_idx].output_type; + auto child_type_id = static_cast(h_device_schema[schema_idx].output_type_id); if (child_type_id == cudf::type_id::STRUCT) { // Use helper to build empty struct with proper nested structure child_col = diff --git a/src/main/cpp/src/protobuf/protobuf.hpp b/src/main/cpp/src/protobuf/protobuf.hpp index 4e496707a1..db2846409b 100644 --- a/src/main/cpp/src/protobuf/protobuf.hpp +++ b/src/main/cpp/src/protobuf/protobuf.hpp @@ -18,8 +18,8 @@ #include #include +#include #include -#include #include #include @@ -77,21 +77,21 @@ struct protobuf_decode_context { std::vector default_ints; std::vector default_floats; 
std::vector default_bools; - std::vector> default_strings; - std::vector> enum_valid_values; - std::vector>> enum_names; + std::vector> default_strings; + std::vector> enum_valid_values; + std::vector>> enum_names; bool fail_on_errors; }; struct protobuf_field_meta_view { nested_field_descriptor const& schema; - cudf::data_type output_type; + cudf::data_type const output_type; int64_t default_int; double default_float; bool default_bool; - std::vector const& default_string; - std::vector const& enum_valid_values; - std::vector> const& enum_names; + cudf::detail::host_vector const& default_string; + cudf::detail::host_vector const& enum_valid_values; + std::vector> const& enum_names; }; bool is_encoding_compatible(nested_field_descriptor const& field, cudf::data_type const& type); diff --git a/src/main/cpp/src/protobuf/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu index a3eeac6c1f..cb9dd7b54d 100644 --- a/src/main/cpp/src/protobuf/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -247,19 +247,16 @@ struct enum_string_lookup_tables { }; enum_string_lookup_tables make_enum_string_lookup_tables( - std::vector const& valid_enums, - std::vector> const& enum_name_bytes, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - rmm::device_uvector d_valid_enums(valid_enums.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_valid_enums.data(), - valid_enums.data(), - valid_enums.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + auto d_valid_enums = cudf::detail::make_device_uvector_async( + valid_enums, stream, rmm::mr::get_current_device_resource()); - std::vector h_name_offsets(valid_enums.size() + 1, 0); + auto h_name_offsets = cudf::detail::make_host_vector(valid_enums.size() + 1, stream); + std::fill(h_name_offsets.begin(), h_name_offsets.end(), 0); int64_t total_name_chars = 0; for (size_t k = 
0; k < enum_name_bytes.size(); ++k) { total_name_chars += static_cast(enum_name_bytes[k].size()); @@ -268,8 +265,8 @@ enum_string_lookup_tables make_enum_string_lookup_tables( h_name_offsets[k + 1] = static_cast(total_name_chars); } - std::vector h_name_chars(total_name_chars); - int32_t cursor = 0; + auto h_name_chars = cudf::detail::make_host_vector(total_name_chars, stream); + int32_t cursor = 0; for (auto const& name : enum_name_bytes) { if (!name.empty()) { std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + cursor); @@ -277,21 +274,16 @@ enum_string_lookup_tables make_enum_string_lookup_tables( } } - rmm::device_uvector d_name_offsets(h_name_offsets.size(), stream, mr); - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_offsets.data(), - h_name_offsets.data(), - h_name_offsets.size() * sizeof(int32_t), - cudaMemcpyHostToDevice, - stream.value())); + auto d_name_offsets = cudf::detail::make_device_uvector_async( + h_name_offsets, stream, rmm::mr::get_current_device_resource()); - rmm::device_uvector d_name_chars(total_name_chars, stream, mr); - if (total_name_chars > 0) { - CUDF_CUDA_TRY(cudaMemcpyAsync(d_name_chars.data(), - h_name_chars.data(), - total_name_chars * sizeof(uint8_t), - cudaMemcpyHostToDevice, - stream.value())); - } + auto d_name_chars = [&]() { + if (total_name_chars > 0) { + return cudf::detail::make_device_uvector_async( + h_name_chars, stream, rmm::mr::get_current_device_resource()); + } + return rmm::device_uvector(0, stream, mr); + }(); return {std::move(d_valid_enums), std::move(d_name_offsets), std::move(d_name_chars)}; } @@ -339,8 +331,8 @@ std::unique_ptr build_enum_string_values_column( std::unique_ptr build_enum_string_column( rmm::device_uvector& enum_values, rmm::device_uvector& valid, - std::vector const& valid_enums, - std::vector> const& enum_name_bytes, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, rmm::device_uvector& d_row_force_null, int num_rows, rmm::cuda_stream_view 
stream, @@ -380,8 +372,8 @@ std::unique_ptr build_repeated_msg_child_enum_string_column( int child_idx, int num_child_fields, int total_count, - std::vector const& valid_enums, - std::vector> const& enum_name_bytes, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, rmm::device_uvector& d_row_force_null, int32_t const* top_row_indices, bool propagate_invalid_rows, @@ -442,8 +434,8 @@ std::unique_ptr build_repeated_enum_string_column( rmm::device_uvector& d_occurrences, int total_count, int num_rows, - std::vector const& valid_enums, - std::vector> const& enum_name_bytes, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, @@ -652,9 +644,9 @@ std::unique_ptr build_nested_struct_column( std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, int num_rows, @@ -680,9 +672,9 @@ std::unique_ptr build_repeated_child_list_column( std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, @@ -707,10 +699,10 @@ std::unique_ptr build_repeated_struct_column( std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, - std::vector> const& default_strings, + 
std::vector> const& default_strings, std::vector const& schema, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error_top, rmm::cuda_stream_view stream, @@ -1086,9 +1078,9 @@ std::unique_ptr build_nested_struct_column( std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, int num_rows, @@ -1437,9 +1429,9 @@ std::unique_ptr build_repeated_child_list_column( std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, @@ -1459,7 +1451,7 @@ std::unique_ptr build_repeated_child_list_column( device_nested_field_descriptor rep_desc; rep_desc.field_number = schema[child_schema_idx].field_number; rep_desc.wire_type = static_cast(schema[child_schema_idx].wire_type); - rep_desc.output_type = schema[child_schema_idx].output_type; + rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); rep_desc.is_repeated = true; rep_desc.parent_idx = -1; rep_desc.depth = 0; @@ -1573,7 +1565,7 @@ std::unique_ptr build_repeated_child_list_column( 0, 0.0, false, - std::vector{}, + cudf::detail::make_host_vector(0, stream), child_schema_idx, enum_valid_values, enum_names, @@ -1630,10 +1622,10 @@ 
std::unique_ptr build_repeated_child_list_column( child_values = make_null_column(cudf::data_type{elem_type_id}, total_rep_count, stream, mr); } } else { - bool as_bytes = (elem_type_id == cudf::type_id::LIST); - auto valid_fn = [] __device__(cudf::size_type) { return true; }; - std::vector empty_default; - child_values = extract_and_build_string_or_bytes_column(as_bytes, + bool as_bytes = (elem_type_id == cudf::type_id::LIST); + auto valid_fn = [] __device__(cudf::size_type) { return true; }; + auto empty_default = cudf::detail::make_host_vector(0, stream); + child_values = extract_and_build_string_or_bytes_column(as_bytes, message_data, total_rep_count, nr_loc, diff --git a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh index 7304653e59..1a3e4013d1 100644 --- a/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh +++ b/src/main/cpp/src/protobuf/protobuf_device_helpers.cuh @@ -77,18 +77,21 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con return -1; // Invalid varint } case wire_type_value(proto_wire_type::I64BIT): + // Check if there's enough data for 8 bytes if (end - cur < 8) { return -1; } return 8; case wire_type_value(proto_wire_type::I32BIT): + // Check if there's enough data for 4 bytes if (end - cur < 4) { return -1; } return 4; case wire_type_value(proto_wire_type::LEN): { uint64_t len; int n; - if (!read_varint(cur, end, len, n)) return -1; + if (!read_varint(cur, end, len, n)) { return -1; } if (len > static_cast(end - cur - n) || - len > static_cast(cuda::std::numeric_limits::max() - n)) + len > static_cast(cuda::std::numeric_limits::max() - n)) { return -1; + } return n + static_cast(len); } case wire_type_value(proto_wire_type::SGROUP): { @@ -97,7 +100,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con while (cur < end && depth > 0) { uint64_t key; int key_bytes; - if (!read_varint(cur, end, key, key_bytes)) return -1; 
+ if (!read_varint(cur, end, key, key_bytes)) { return -1; } cur += key_bytes; int inner_wt = static_cast(key & 0x7); @@ -105,14 +108,14 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con --depth; if (depth == 0) { return static_cast(cur - start); } } else if (inner_wt == wire_type_value(proto_wire_type::SGROUP)) { - if (++depth > 32) return -1; + if (++depth > 32) { return -1; } } else { int inner_size = -1; switch (inner_wt) { case wire_type_value(proto_wire_type::VARINT): { uint64_t dummy; int vbytes; - if (!read_varint(cur, end, dummy, vbytes)) return -1; + if (!read_varint(cur, end, dummy, vbytes)) { return -1; } inner_size = vbytes; break; } @@ -120,7 +123,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con case wire_type_value(proto_wire_type::LEN): { uint64_t len; int len_bytes; - if (!read_varint(cur, end, len, len_bytes)) return -1; + if (!read_varint(cur, end, len, len_bytes)) { return -1; } if (len > static_cast(cuda::std::numeric_limits::max() - len_bytes)) { return -1; } @@ -130,7 +133,7 @@ __device__ inline int get_wire_type_size(int wt, uint8_t const* cur, uint8_t con case wire_type_value(proto_wire_type::I32BIT): inner_size = 4; break; default: return -1; } - if (inner_size < 0 || cur + inner_size > end) return -1; + if (inner_size < 0 || cur + inner_size > end) { return -1; } cur += inner_size; } } @@ -153,9 +156,9 @@ __device__ inline bool skip_field(uint8_t const* cur, if (wt == wire_type_value(proto_wire_type::EGROUP)) { return false; } int size = get_wire_type_size(wt, cur, end); - if (size < 0) return false; + if (size < 0) { return false; } // Ensure we don't skip past the end of the buffer - if (cur + size > end) return false; + if (cur + size > end) { return false; } out_cur = cur + size; return true; } @@ -171,7 +174,7 @@ __device__ inline bool get_field_data_location( // For length-delimited, read the length prefix uint64_t len; int len_bytes; - if (!read_varint(cur, end, 
len, len_bytes)) return false; + if (!read_varint(cur, end, len, len_bytes)) { return false; } if (len > static_cast(end - cur - len_bytes) || len > static_cast(cuda::std::numeric_limits::max())) { return false; @@ -181,7 +184,7 @@ __device__ inline bool get_field_data_location( } else { // For fixed-size and varint fields int field_size = get_wire_type_size(wt, cur, end); - if (field_size < 0) return false; + if (field_size < 0) { return false; } data_offset = 0; data_length = field_size; } @@ -284,7 +287,7 @@ __device__ __forceinline__ int lookup_field(int field_number, return lookup_table[field_number]; } for (int f = 0; f < num_fields; f++) { - if (field_descs[f].field_number == field_number) return f; + if (field_descs[f].field_number == field_number) { return f; } } return -1; } diff --git a/src/main/cpp/src/protobuf/protobuf_host_helpers.cuh b/src/main/cpp/src/protobuf/protobuf_host_helpers.cuh index b794ca0121..0f204c4d84 100644 --- a/src/main/cpp/src/protobuf/protobuf_host_helpers.cuh +++ b/src/main/cpp/src/protobuf/protobuf_host_helpers.cuh @@ -259,6 +259,20 @@ std::unique_ptr make_empty_list_column(std::unique_ptr +inline cudf::type_id get_output_type_id(FieldT const& field) +{ + if constexpr (std::is_same_v) { + return static_cast(field.output_type_id); + } else { + return field.output_type; + } +} + template std::unique_ptr make_empty_struct_column_with_schema( SchemaT const& schema, @@ -271,7 +285,7 @@ std::unique_ptr make_empty_struct_column_with_schema( std::vector> children; for (int child_idx : child_indices) { - auto child_type = cudf::data_type{schema[child_idx].output_type}; + auto child_type = cudf::data_type{get_output_type_id(schema[child_idx])}; std::unique_ptr child_col; if (child_type.id() == cudf::type_id::STRUCT) { @@ -480,7 +494,7 @@ void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_ void validate_enum_and_propagate_rows(rmm::device_uvector const& values, rmm::device_uvector& valid, - std::vector const& 
valid_enums, + cudf::detail::host_vector const& valid_enums, rmm::device_uvector& row_invalid, int num_items, int32_t const* top_row_indices, @@ -506,8 +520,8 @@ std::unique_ptr make_null_list_column_with_child( std::unique_ptr build_enum_string_column( rmm::device_uvector& enum_values, rmm::device_uvector& valid, - std::vector const& valid_enums, - std::vector> const& enum_name_bytes, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, rmm::device_uvector& d_row_force_null, int num_rows, rmm::cuda_stream_view stream, @@ -525,8 +539,8 @@ std::unique_ptr build_repeated_enum_string_column( rmm::device_uvector& d_occurrences, int total_count, int num_rows, - std::vector const& valid_enums, - std::vector> const& enum_name_bytes, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, @@ -559,9 +573,9 @@ std::unique_ptr build_nested_struct_column( std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, int num_rows, @@ -584,9 +598,9 @@ std::unique_ptr build_repeated_child_list_column( std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, @@ -611,10 +625,10 @@ std::unique_ptr 
build_repeated_struct_column( std::vector const& default_ints, std::vector const& default_floats, std::vector const& default_bools, - std::vector> const& default_strings, + std::vector> const& default_strings, std::vector const& schema, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error_top, rmm::cuda_stream_view stream, @@ -629,7 +643,7 @@ inline std::unique_ptr extract_and_build_string_or_bytes_column( CopyProvider const& copy_provider, ValidityFn validity_fn, bool has_default, - std::vector const& default_bytes, + cudf::detail::host_vector const& default_bytes, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) @@ -637,10 +651,8 @@ inline std::unique_ptr extract_and_build_string_or_bytes_column( int32_t def_len = has_default ? static_cast(default_bytes.size()) : 0; rmm::device_uvector d_default(0, stream, mr); if (has_default && def_len > 0) { - auto h_default = cudf::detail::make_host_vector(def_len, stream); - std::copy(default_bytes.begin(), default_bytes.end(), h_default.begin()); d_default = cudf::detail::make_device_uvector_async( - h_default, stream, rmm::mr::get_current_device_resource()); + default_bytes, stream, rmm::mr::get_current_device_resource()); } rmm::device_uvector lengths(num_rows, stream, mr); @@ -701,10 +713,10 @@ inline std::unique_ptr extract_typed_column( int64_t default_int, double default_float, bool default_bool, - std::vector const& default_string, + cudf::detail::host_vector const& default_string, int schema_idx, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, rmm::device_uvector& d_row_force_null, rmm::device_uvector& d_error, rmm::cuda_stream_view stream, @@ -866,6 +878,7 @@ inline std::unique_ptr 
build_repeated_scalar_column( rmm::device_async_resource_ref mr) { auto const input_null_count = binary_input.null_count(); + auto const field_type_id = static_cast(field_desc.output_type_id); if (total_count == 0) { // All rows have count=0, but we still need to check input nulls @@ -876,9 +889,8 @@ inline std::unique_ptr build_repeated_scalar_column( offsets.release(), rmm::device_buffer{}, 0); - auto elem_type = - field_desc.output_type == cudf::type_id::LIST ? cudf::type_id::UINT8 : field_desc.output_type; - auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); + auto elem_type = field_type_id == cudf::type_id::LIST ? cudf::type_id::UINT8 : field_type_id; + auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); if (input_null_count > 0) { // Copy input null mask - only input nulls produce output nulls @@ -945,11 +957,8 @@ inline std::unique_ptr build_repeated_scalar_column( list_offs.release(), rmm::device_buffer{}, 0); - auto child_col = std::make_unique(cudf::data_type{field_desc.output_type}, - total_count, - values.release(), - rmm::device_buffer{}, - 0); + auto child_col = std::make_unique( + cudf::data_type{field_type_id}, total_count, values.release(), rmm::device_buffer{}, 0); // Only rows where INPUT is null should produce null output // Rows with valid input but count=0 should produce empty array [] diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cu b/src/main/cpp/src/protobuf/protobuf_kernels.cu index 53e42e8784..a3907ce4e9 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cu @@ -69,7 +69,7 @@ CUDF_KERNEL void scan_all_fields_kernel( { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; - if (row >= in.size()) return; + if (row >= in.size()) { return; } auto mark_row_error = [&]() { if (row_has_invalid_data != nullptr) { row_has_invalid_data[row] = true; } @@ -375,7 +375,7 
@@ CUDF_KERNEL void count_repeated_fields_kernel(cudf::column_device_view const d_i { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; - if (row >= in.size()) return; + if (row >= in.size()) { return; } // Initialize repeated counts to 0 for (int f = 0; f < num_repeated_fields; f++) { @@ -484,14 +484,14 @@ CUDF_KERNEL void count_repeated_fields_kernel(cudf::column_device_view const d_i if (i >= 0) { int schema_idx = nested_field_indices[i]; if (schema[schema_idx].depth == depth_level) { - if (!handle_nested(i)) return; + if (!handle_nested(i)) { return; } } } } else { for (int i = 0; i < num_nested_fields; i++) { int schema_idx = nested_field_indices[i]; if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { - if (!handle_nested(i)) return; + if (!handle_nested(i)) { return; } } } } @@ -519,7 +519,7 @@ CUDF_KERNEL void scan_all_repeated_occurrences_kernel(cudf::column_device_view c { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; - if (row >= in.size()) return; + if (row >= in.size()) { return; } if (in.nullable() && in.is_null(row)) { return; } @@ -572,12 +572,12 @@ CUDF_KERNEL void scan_all_repeated_occurrences_kernel(cudf::column_device_view c if (fn_to_desc_idx != nullptr && fn > 0 && fn < fn_to_desc_size) { int f = fn_to_desc_idx[fn]; if (f >= 0 && f < num_scan_fields) { - if (!try_scan(f)) return; + if (!try_scan(f)) { return; } } } else { for (int f = 0; f < num_scan_fields; f++) { if (scan_descs[f].field_number == fn) { - if (!try_scan(f)) return; + if (!try_scan(f)) { return; } } } } @@ -623,7 +623,7 @@ CUDF_KERNEL void scan_nested_message_fields_kernel(uint8_t const* message_data, int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_parent_rows) return; + if (row >= num_parent_rows) { return; } for (int f = 0; f < num_fields; f++) { 
output_locations[flat_index( @@ -727,7 +727,7 @@ CUDF_KERNEL void scan_repeated_message_children_kernel( int child_lookup_size) { auto occ_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (occ_idx >= num_occurrences) return; + if (occ_idx >= num_occurrences) { return; } // Initialize child locations to not found for (int f = 0; f < num_child_fields; f++) { @@ -737,7 +737,7 @@ CUDF_KERNEL void scan_repeated_message_children_kernel( } auto const& msg_loc = msg_locs[occ_idx]; - if (msg_loc.offset < 0) return; + if (msg_loc.offset < 0) { return; } cudf::size_type row_offset = msg_row_offsets[occ_idx]; int64_t msg_start_off = static_cast(row_offset) + msg_loc.offset; @@ -853,7 +853,7 @@ CUDF_KERNEL void count_repeated_in_nested_kernel(uint8_t const* message_data, int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; + if (row >= num_rows) { return; } // Initialize counts for (int ri = 0; ri < num_repeated; ri++) { @@ -863,7 +863,7 @@ CUDF_KERNEL void count_repeated_in_nested_kernel(uint8_t const* message_data, } auto const& parent_loc = parent_locs[row]; - if (parent_loc.offset < 0) return; + if (parent_loc.offset < 0) { return; } cudf::size_type row_off; row_off = row_offsets[row] - base_offset; @@ -932,7 +932,7 @@ CUDF_KERNEL void scan_repeated_in_nested_kernel(uint8_t const* message_data, int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; + if (row >= num_rows) { return; } int write_idx = occ_prefix_sums[row]; int write_end = occ_prefix_sums[row + 1]; @@ -1006,7 +1006,7 @@ CUDF_KERNEL void compute_nested_struct_locations_kernel( int* error_flag) { int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_count) return; + if (idx >= total_count) { return; } nested_locs[idx] = child_locs[flat_index(static_cast(idx), static_cast(num_child_fields), @@ -1037,7 +1037,7 @@ CUDF_KERNEL void 
compute_grandchild_parent_locations_kernel( int* error_flag) { int row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= num_rows) return; + if (row >= num_rows) { return; } auto const& parent_loc = parent_locs[row]; auto const& child_loc = child_locs[flat_index(static_cast(row), @@ -1075,7 +1075,7 @@ CUDF_KERNEL void compute_virtual_parents_for_nested_repeated_kernel( int* error_flag) { int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_count) return; + if (idx >= total_count) { return; } auto const& occ = occurrences[idx]; auto const& ploc = parent_locations[occ.row_idx]; @@ -1113,7 +1113,7 @@ CUDF_KERNEL void compute_msg_locations_from_occurrences_kernel( int* error_flag) { int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_count) return; + if (idx >= total_count) { return; } auto const& occ = occurrences[idx]; auto row_offset = static_cast(list_offsets[occ.row_idx]) - base_offset; @@ -1139,7 +1139,7 @@ CUDF_KERNEL void extract_strided_locations_kernel(field_location const* nested_l int num_rows) { int row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= num_rows) return; + if (row >= num_rows) { return; } parent_locs[row] = nested_locations[flat_index( static_cast(row), static_cast(num_fields), static_cast(field_idx))]; } @@ -1165,11 +1165,11 @@ CUDF_KERNEL void check_required_fields_kernel( int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; + if (row >= num_rows) { return; } if (input_null_mask != nullptr && !cudf::bit_is_set(input_null_mask, row + input_offset)) { return; } - if (parent_locs != nullptr && parent_locs[row].offset < 0) return; + if (parent_locs != nullptr && parent_locs[row].offset < 0) { return; } for (int f = 0; f < num_fields; f++) { if (is_required[f] != 0 && locations[flat_index(static_cast(row), @@ -1211,10 +1211,10 @@ CUDF_KERNEL void validate_enum_values_kernel( int num_rows) { auto row = static_cast(blockIdx.x * blockDim.x + 
threadIdx.x); - if (row >= num_rows) return; + if (row >= num_rows) { return; } // Skip if already invalid (field was missing) - missing field is not an enum error - if (!valid[row]) return; + if (!valid[row]) { return; } int32_t val = values[row]; @@ -1259,7 +1259,7 @@ CUDF_KERNEL void compute_enum_string_lengths_kernel( int num_rows) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; + if (row >= num_rows) { return; } if (!valid[row]) { lengths[row] = 0; @@ -1301,8 +1301,8 @@ CUDF_KERNEL void copy_enum_string_chars_kernel( int num_rows) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) return; - if (!valid[row]) return; + if (row >= num_rows) { return; } + if (!valid[row]) { return; } int32_t val = values[row]; int left = 0; @@ -1349,7 +1349,7 @@ void launch_scan_all_fields(cudf::column_device_view const& d_in, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) return; + if (num_rows == 0) { return; } auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); scan_all_fields_kernel<<>>(d_in, field_descs, @@ -1379,7 +1379,7 @@ void launch_count_repeated_fields(cudf::column_device_view const& d_in, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) return; + if (num_rows == 0) { return; } auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); count_repeated_fields_kernel<<>>( d_in, @@ -1408,7 +1408,7 @@ void launch_scan_all_repeated_occurrences(cudf::column_device_view const& d_in, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) return; + if (num_rows == 0) { return; } auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); scan_all_repeated_occurrences_kernel<<>>( d_in, scan_descs, num_scan_fields, error_flag, fn_to_desc_idx, fn_to_desc_size); @@ -1421,7 +1421,7 @@ void launch_extract_strided_locations(field_location const* 
nested_locations, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) return; + if (num_rows == 0) { return; } auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); extract_strided_locations_kernel<<>>( nested_locations, field_idx, num_fields, parent_locs, num_rows); @@ -1439,7 +1439,7 @@ void launch_scan_nested_message_fields(uint8_t const* message_data, int* error_flag, rmm::cuda_stream_view stream) { - if (num_parent_rows == 0) return; + if (num_parent_rows == 0) { return; } auto const blocks = static_cast((num_parent_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); scan_nested_message_fields_kernel<<>>( @@ -1468,7 +1468,7 @@ void launch_scan_repeated_message_children(uint8_t const* message_data, int child_lookup_size, rmm::cuda_stream_view stream) { - if (num_occurrences == 0) return; + if (num_occurrences == 0) { return; } auto const blocks = static_cast((num_occurrences + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); scan_repeated_message_children_kernel<<>>( @@ -1499,7 +1499,7 @@ void launch_count_repeated_in_nested(uint8_t const* message_data, int* error_flag, rmm::cuda_stream_view stream) { - if (num_rows == 0) return; + if (num_rows == 0) { return; } auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); count_repeated_in_nested_kernel<<>>( message_data, @@ -1529,7 +1529,7 @@ void launch_scan_repeated_in_nested(uint8_t const* message_data, int* error_flag, rmm::cuda_stream_view stream) { - if (num_rows == 0) return; + if (num_rows == 0) { return; } auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); scan_repeated_in_nested_kernel<<>>( message_data, @@ -1556,7 +1556,7 @@ void launch_compute_nested_struct_locations(field_location const* child_locs, int* error_flag, rmm::cuda_stream_view stream) { - if (total_count == 0) return; + if (total_count == 0) { return; } auto const blocks = static_cast((total_count + THREADS_PER_BLOCK - 
1u) / THREADS_PER_BLOCK); compute_nested_struct_locations_kernel<<>>( child_locs, @@ -1579,7 +1579,7 @@ void launch_compute_grandchild_parent_locations(field_location const* parent_loc int* error_flag, rmm::cuda_stream_view stream) { - if (num_rows == 0) return; + if (num_rows == 0) { return; } auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); compute_grandchild_parent_locations_kernel<<>>( parent_locs, child_locs, child_idx, num_child_fields, gc_parent_abs, num_rows, error_flag); @@ -1594,7 +1594,7 @@ void launch_compute_virtual_parents_for_nested_repeated(repeated_occurrence cons int* error_flag, rmm::cuda_stream_view stream) { - if (total_count == 0) return; + if (total_count == 0) { return; } auto const blocks = static_cast((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); compute_virtual_parents_for_nested_repeated_kernel<<((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); compute_msg_locations_from_occurrences_kernel<<>>( occurrences, list_offsets, base_offset, msg_locs, msg_row_offsets, total_count, error_flag); @@ -1631,7 +1631,7 @@ void launch_validate_enum_values(int32_t const* values, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) return; + if (num_rows == 0) { return; } auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); validate_enum_values_kernel<<>>( values, valid, row_has_invalid_enum, valid_enum_values, num_valid_values, num_rows); @@ -1646,7 +1646,7 @@ void launch_compute_enum_string_lengths(int32_t const* values, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) return; + if (num_rows == 0) { return; } auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); compute_enum_string_lengths_kernel<<>>( values, valid, valid_enum_values, enum_name_offsets, num_valid_values, lengths, num_rows); @@ -1663,7 +1663,7 @@ void launch_copy_enum_string_chars(int32_t const* values, int num_rows, 
rmm::cuda_stream_view stream) { - if (num_rows == 0) return; + if (num_rows == 0) { return; } auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); copy_enum_string_chars_kernel<<>>(values, valid, @@ -1764,7 +1764,7 @@ void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_ void validate_enum_and_propagate_rows(rmm::device_uvector const& values, rmm::device_uvector& valid, - std::vector const& valid_enums, + cudf::detail::host_vector const& valid_enums, rmm::device_uvector& row_invalid, int num_items, int32_t const* top_row_indices, @@ -1775,10 +1775,8 @@ void validate_enum_and_propagate_rows(rmm::device_uvector const& values if (num_items == 0 || valid_enums.empty()) { return; } auto const blocks = static_cast((num_items + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); - auto h_valid_enums = cudf::detail::make_host_vector(valid_enums.size(), stream); - std::copy(valid_enums.begin(), valid_enums.end(), h_valid_enums.begin()); auto d_valid_enums = cudf::detail::make_device_uvector_async( - h_valid_enums, stream, rmm::mr::get_current_device_resource()); + valid_enums, stream, rmm::mr::get_current_device_resource()); rmm::device_uvector item_invalid(num_items, stream, mr); thrust::fill(rmm::exec_policy_nosync(stream), item_invalid.begin(), item_invalid.end(), false); diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cuh b/src/main/cpp/src/protobuf/protobuf_kernels.cuh index d29ba09648..0804ac4e59 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cuh +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cuh @@ -139,7 +139,7 @@ CUDF_KERNEL void extract_varint_kernel(uint8_t const* message_data, int64_t default_value = 0) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_items) return; + if (idx >= total_items) { return; } int32_t data_offset = 0; auto loc = loc_provider.get(idx, data_offset); @@ -157,9 +157,9 @@ CUDF_KERNEL void extract_varint_kernel(uint8_t const* 
message_data, if (loc.offset < 0) { if (has_default) { write_value(&out[idx], static_cast(default_value)); - if (valid) valid[idx] = true; + if (valid) { valid[idx] = true; } } else { - if (valid) valid[idx] = false; + if (valid) { valid[idx] = false; } } return; } @@ -171,13 +171,13 @@ CUDF_KERNEL void extract_varint_kernel(uint8_t const* message_data, int n; if (!read_varint(cur, cur_end, v, n)) { set_error_once(error_flag, ERR_VARINT); - if (valid) valid[idx] = false; + if (valid) { valid[idx] = false; } return; } if constexpr (ZigZag) { v = (v >> 1) ^ (-(v & 1)); } write_value(&out[idx], v); - if (valid) valid[idx] = true; + if (valid) { valid[idx] = true; } } template @@ -191,7 +191,7 @@ CUDF_KERNEL void extract_fixed_kernel(uint8_t const* message_data, OutputType default_value = OutputType{}) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_items) return; + if (idx >= total_items) { return; } int32_t data_offset = 0; auto loc = loc_provider.get(idx, data_offset); @@ -199,9 +199,9 @@ CUDF_KERNEL void extract_fixed_kernel(uint8_t const* message_data, if (loc.offset < 0) { if (has_default) { out[idx] = default_value; - if (valid) valid[idx] = true; + if (valid) { valid[idx] = true; } } else { - if (valid) valid[idx] = false; + if (valid) { valid[idx] = false; } } return; } @@ -212,7 +212,7 @@ CUDF_KERNEL void extract_fixed_kernel(uint8_t const* message_data, if constexpr (WT == wire_type_value(proto_wire_type::I32BIT)) { if (loc.length < 4) { set_error_once(error_flag, ERR_FIXED_LEN); - if (valid) valid[idx] = false; + if (valid) { valid[idx] = false; } return; } uint32_t raw = load_le(cur); @@ -220,7 +220,7 @@ CUDF_KERNEL void extract_fixed_kernel(uint8_t const* message_data, } else { if (loc.length < 8) { set_error_once(error_flag, ERR_FIXED_LEN); - if (valid) valid[idx] = false; + if (valid) { valid[idx] = false; } return; } uint64_t raw = load_le(cur); @@ -228,7 +228,7 @@ CUDF_KERNEL void extract_fixed_kernel(uint8_t const* 
message_data, } out[idx] = value; - if (valid) valid[idx] = true; + if (valid) { valid[idx] = true; } } // ============================================================================ @@ -257,7 +257,7 @@ CUDF_KERNEL void extract_varint_batched_kernel(uint8_t const* message_data, { int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); int fi = static_cast(blockIdx.y); - if (row >= num_rows || fi >= num_descs) return; + if (row >= num_rows || fi >= num_descs) { return; } auto const& desc = descs[fi]; auto loc = locations[row * num_loc_fields + desc.loc_field_idx]; @@ -310,7 +310,7 @@ CUDF_KERNEL void extract_fixed_batched_kernel(uint8_t const* message_data, { int row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); int fi = static_cast(blockIdx.y); - if (row >= num_rows || fi >= num_descs) return; + if (row >= num_rows || fi >= num_descs) { return; } auto const& desc = descs[fi]; auto loc = locations[row * num_loc_fields + desc.loc_field_idx]; @@ -365,7 +365,7 @@ CUDF_KERNEL void extract_lengths_kernel(LocationProvider loc_provider, int32_t default_length = 0) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_items) return; + if (idx >= total_items) { return; } int32_t data_offset = 0; auto loc = loc_provider.get(idx, data_offset); @@ -390,7 +390,7 @@ CUDF_KERNEL void copy_varlen_data_kernel(uint8_t const* message_data, int default_len = 0) { auto idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (idx >= total_items) return; + if (idx >= total_items) { return; } int32_t data_offset = 0; auto loc = loc_provider.get(idx, data_offset); diff --git a/src/main/cpp/src/protobuf/protobuf_types.cuh b/src/main/cpp/src/protobuf/protobuf_types.cuh index d9c5543b46..c575447ded 100644 --- a/src/main/cpp/src/protobuf/protobuf_types.cuh +++ b/src/main/cpp/src/protobuf/protobuf_types.cuh @@ -98,7 +98,7 @@ struct device_nested_field_descriptor { int parent_idx; int depth; int wire_type; - cudf::type_id output_type; + int 
output_type_id; int encoding; bool is_repeated; bool is_required; @@ -114,7 +114,7 @@ struct device_nested_field_descriptor { parent_idx(src.parent_idx), depth(src.depth), wire_type(static_cast(src.wire_type)), - output_type(src.output_type), + output_type_id(static_cast(src.output_type)), encoding(static_cast(src.encoding)), is_repeated(src.is_repeated), is_required(src.is_required), diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java index bf56c2662b..eabe8e58d0 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java @@ -36,12 +36,15 @@ private ProtobufSchemaDescriptor makeDescriptor( int encoding, int[] enumValidValues, byte[][] enumNames) { + int outputType = (encoding == Protobuf.ENC_ENUM_STRING) + ? ai.rapids.cudf.DType.STRING.getTypeId().getNativeId() + : ai.rapids.cudf.DType.INT32.getTypeId().getNativeId(); return new ProtobufSchemaDescriptor( new int[]{1}, new int[]{-1}, new int[]{0}, new int[]{Protobuf.WT_VARINT}, - new int[]{ai.rapids.cudf.DType.INT32.getTypeId().getNativeId()}, + new int[]{outputType}, new int[]{encoding}, new boolean[]{isRepeated}, new boolean[]{false}, diff --git a/thirdparty/cudf b/thirdparty/cudf index 3384db296f..b13ada2c89 160000 --- a/thirdparty/cudf +++ b/thirdparty/cudf @@ -1 +1 @@ -Subproject commit 3384db296ff5663abd22ce65994b1a852fed73e3 +Subproject commit b13ada2c8970be62e713d197e0ad3fa596ea32ef From 6e00831afd8a52398ece84fc9d705b9cedf1346c Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Mon, 30 Mar 2026 11:28:30 +0800 Subject: [PATCH 102/107] apply refactor on headers Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf/protobuf.cu | 2 +- .../cpp/src/protobuf/protobuf_builders.cu | 2 +- .../src/protobuf/protobuf_host_helpers.cuh | 979 ------------------ 
.../src/protobuf/protobuf_host_helpers.hpp | 310 ++++++ .../cpp/src/protobuf/protobuf_kernels.cuh | 665 ++++++++++++ 5 files changed, 977 insertions(+), 981 deletions(-) delete mode 100644 src/main/cpp/src/protobuf/protobuf_host_helpers.cuh create mode 100644 src/main/cpp/src/protobuf/protobuf_host_helpers.hpp diff --git a/src/main/cpp/src/protobuf/protobuf.cu b/src/main/cpp/src/protobuf/protobuf.cu index e618201dca..e5912fd4e2 100644 --- a/src/main/cpp/src/protobuf/protobuf.cu +++ b/src/main/cpp/src/protobuf/protobuf.cu @@ -15,7 +15,7 @@ */ #include "nvtx_ranges.hpp" -#include "protobuf/protobuf_host_helpers.cuh" +#include "protobuf/protobuf_kernels.cuh" #include diff --git a/src/main/cpp/src/protobuf/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu index cb9dd7b54d..0cf68dae2d 100644 --- a/src/main/cpp/src/protobuf/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "protobuf/protobuf_host_helpers.cuh" +#include "protobuf/protobuf_kernels.cuh" #include #include diff --git a/src/main/cpp/src/protobuf/protobuf_host_helpers.cuh b/src/main/cpp/src/protobuf/protobuf_host_helpers.cuh deleted file mode 100644 index 0f204c4d84..0000000000 --- a/src/main/cpp/src/protobuf/protobuf_host_helpers.cuh +++ /dev/null @@ -1,979 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "protobuf/protobuf_kernels.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace spark_rapids_jni::protobuf::detail { - -// ============================================================================ -// Field number lookup table helpers -// ============================================================================ - -/** - * Build a host-side direct-mapped lookup table: field_number -> index. - * Returns an empty vector if the max field number exceeds the threshold. - * - * @tparam GetFieldNumber callable (int i) -> int returning the field number for index i - */ -template -inline std::vector build_lookup_table(int num_entries, GetFieldNumber get_fn) -{ - int max_fn = 0; - for (int i = 0; i < num_entries; i++) { - max_fn = std::max(max_fn, get_fn(i)); - } - if (max_fn > FIELD_LOOKUP_TABLE_MAX) { return {}; } - std::vector table(max_fn + 1, -1); - for (int i = 0; i < num_entries; i++) { - table[get_fn(i)] = i; - } - return table; -} - -inline std::vector build_index_lookup_table(nested_field_descriptor const* schema, - int const* field_indices, - int num_indices) -{ - return build_lookup_table(num_indices, - [&](int i) { return schema[field_indices[i]].field_number; }); -} - -inline std::vector build_field_lookup_table(field_descriptor const* descs, int num_fields) -{ - return build_lookup_table(num_fields, [&](int i) { return descs[i].field_number; }); -} - -template -inline std::pair make_null_mask_from_valid( - rmm::device_uvector const& valid, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - auto begin = thrust::make_counting_iterator(0); - auto end = begin + valid.size(); - auto pred = [ptr = valid.data()] __device__(cudf::size_type i) { - return static_cast(ptr[i]); - }; 
- return cudf::detail::valid_if(begin, end, pred, stream, mr); -} - -template -std::unique_ptr extract_and_build_scalar_column(cudf::data_type dt, - int num_rows, - LaunchFn&& launch_extract, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - rmm::device_uvector out(num_rows, stream, mr); - rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); - launch_extract(out.data(), valid.data()); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - return std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count); -} - -template -// Shared integer extractor for INT32/INT64/UINT32/UINT64 decode paths. -inline void extract_integer_into_buffers(uint8_t const* message_data, - LocationProvider const& loc_provider, - int num_rows, - int blocks, - int threads, - bool has_default, - int64_t default_value, - int encoding, - bool enable_zigzag, - T* out_ptr, - bool* valid_ptr, - int* error_ptr, - rmm::cuda_stream_view stream) -{ - if (enable_zigzag && encoding == encoding_value(proto_encoding::ZIGZAG)) { - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - error_ptr, - has_default, - default_value); - } else if (encoding == encoding_value(proto_encoding::FIXED)) { - if constexpr (sizeof(T) == 4) { - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - error_ptr, - has_default, - static_cast(default_value)); - } else { - static_assert(sizeof(T) == 8, "extract_integer_into_buffers only supports 32/64-bit"); - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - error_ptr, - has_default, - static_cast(default_value)); - } - } else { - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_rows, - out_ptr, - valid_ptr, - error_ptr, - has_default, - default_value); - } -} - -template -// Builds a scalar column for integer-like protobuf fields. 
-std::unique_ptr extract_and_build_integer_column(cudf::data_type dt, - uint8_t const* message_data, - LocationProvider const& loc_provider, - int num_rows, - int blocks, - int threads, - rmm::device_uvector& d_error, - bool has_default, - int64_t default_value, - int encoding, - bool enable_zigzag, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - return extract_and_build_scalar_column( - dt, - num_rows, - [&](T* out_ptr, bool* valid_ptr) { - extract_integer_into_buffers(message_data, - loc_provider, - num_rows, - blocks, - threads, - has_default, - default_value, - encoding, - enable_zigzag, - out_ptr, - valid_ptr, - d_error.data(), - stream); - }, - stream, - mr); -} - -struct extract_strided_count { - repeated_field_info const* info; - int field_idx; - int num_fields; - - __device__ int32_t operator()(int row) const - { - return info[flat_index(static_cast(row), - static_cast(num_fields), - static_cast(field_idx))] - .count; - } -}; - -/** - * Find all child field indices for a given parent index in the schema. - * This is a commonly used pattern throughout the codebase. 
- * - * @param schema The schema vector (either nested_field_descriptor or - * device_nested_field_descriptor) - * @param num_fields Number of fields in the schema - * @param parent_idx The parent index to search for - * @return Vector of child field indices - */ -template -std::vector find_child_field_indices(SchemaT const& schema, int num_fields, int parent_idx) -{ - std::vector child_indices; - for (int i = 0; i < num_fields; i++) { - if (schema[i].parent_idx == parent_idx) { child_indices.push_back(i); } - } - return child_indices; -} - -// Forward declarations needed by make_empty_struct_column_with_schema -std::unique_ptr make_empty_column_safe(cudf::data_type dtype, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr); - -std::unique_ptr make_empty_list_column(std::unique_ptr element_col, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr); - -/** - * Extract output type from either nested_field_descriptor (.output_type is cudf::type_id) - * or device_nested_field_descriptor (.output_type_id is int). 
- */ -template -inline cudf::type_id get_output_type_id(FieldT const& field) -{ - if constexpr (std::is_same_v) { - return static_cast(field.output_type_id); - } else { - return field.output_type; - } -} - -template -std::unique_ptr make_empty_struct_column_with_schema( - SchemaT const& schema, - int parent_idx, - int num_fields, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - auto child_indices = find_child_field_indices(schema, num_fields, parent_idx); - - std::vector> children; - for (int child_idx : child_indices) { - auto child_type = cudf::data_type{get_output_type_id(schema[child_idx])}; - - std::unique_ptr child_col; - if (child_type.id() == cudf::type_id::STRUCT) { - child_col = make_empty_struct_column_with_schema(schema, child_idx, num_fields, stream, mr); - } else { - child_col = make_empty_column_safe(child_type, stream, mr); - } - - if (schema[child_idx].is_repeated) { - child_col = make_empty_list_column(std::move(child_col), stream, mr); - } - - children.push_back(std::move(child_col)); - } - - return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer{}, stream, mr); -} - -// ============================================================================ -// Host wrapper declarations for kernel launches -// ============================================================================ - -void launch_scan_all_fields(cudf::column_device_view const& d_in, - field_descriptor const* field_descs, - int num_fields, - int const* field_lookup, - int field_lookup_size, - field_location* locations, - int* error_flag, - bool* row_has_invalid_data, - int num_rows, - rmm::cuda_stream_view stream); - -void launch_count_repeated_fields(cudf::column_device_view const& d_in, - device_nested_field_descriptor const* schema, - int num_fields, - int depth_level, - repeated_field_info* repeated_info, - int num_repeated_fields, - int const* repeated_field_indices, - field_location* nested_locations, - int num_nested_fields, - int 
const* nested_field_indices, - int* error_flag, - int const* fn_to_rep_idx, - int fn_to_rep_size, - int const* fn_to_nested_idx, - int fn_to_nested_size, - int num_rows, - rmm::cuda_stream_view stream); - -void launch_scan_all_repeated_occurrences(cudf::column_device_view const& d_in, - repeated_field_scan_desc const* scan_descs, - int num_scan_fields, - int* error_flag, - int const* fn_to_desc_idx, - int fn_to_desc_size, - int num_rows, - rmm::cuda_stream_view stream); - -void launch_extract_strided_locations(field_location const* nested_locations, - int field_idx, - int num_fields, - field_location* parent_locs, - int num_rows, - rmm::cuda_stream_view stream); - -void launch_scan_nested_message_fields(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* parent_row_offsets, - cudf::size_type parent_base_offset, - field_location const* parent_locations, - int num_parent_rows, - field_descriptor const* field_descs, - int num_fields, - field_location* output_locations, - int* error_flag, - rmm::cuda_stream_view stream); - -void launch_scan_repeated_message_children(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* msg_row_offsets, - field_location const* msg_locs, - int num_occurrences, - field_descriptor const* child_descs, - int num_child_fields, - field_location* child_locs, - int* error_flag, - int const* child_lookup, - int child_lookup_size, - rmm::cuda_stream_view stream); - -void launch_count_repeated_in_nested(uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_rows, - device_nested_field_descriptor const* schema, - int num_fields, - repeated_field_info* repeated_info, - int num_repeated, - int const* repeated_indices, - int* error_flag, - rmm::cuda_stream_view stream); - -void launch_scan_repeated_in_nested(uint8_t const* message_data, - 
cudf::size_type message_data_size, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_rows, - device_nested_field_descriptor const* schema, - int32_t const* occ_prefix_sums, - int const* repeated_indices, - repeated_occurrence* occurrences, - int* error_flag, - rmm::cuda_stream_view stream); - -void launch_compute_nested_struct_locations(field_location const* child_locs, - field_location const* msg_locs, - cudf::size_type const* msg_row_offsets, - int child_idx, - int num_child_fields, - field_location* nested_locs, - cudf::size_type* nested_row_offsets, - int total_count, - int* error_flag, - rmm::cuda_stream_view stream); - -void launch_compute_grandchild_parent_locations(field_location const* parent_locs, - field_location const* child_locs, - int child_idx, - int num_child_fields, - field_location* gc_parent_abs, - int num_rows, - int* error_flag, - rmm::cuda_stream_view stream); - -void launch_compute_virtual_parents_for_nested_repeated(repeated_occurrence const* occurrences, - cudf::size_type const* row_list_offsets, - field_location const* parent_locations, - cudf::size_type* virtual_row_offsets, - field_location* virtual_parent_locs, - int total_count, - int* error_flag, - rmm::cuda_stream_view stream); - -void launch_compute_msg_locations_from_occurrences(repeated_occurrence const* occurrences, - cudf::size_type const* list_offsets, - cudf::size_type base_offset, - field_location* msg_locs, - cudf::size_type* msg_row_offsets, - int total_count, - int* error_flag, - rmm::cuda_stream_view stream); - -void launch_validate_enum_values(int32_t const* values, - bool* valid, - bool* row_has_invalid_enum, - int32_t const* valid_enum_values, - int num_valid_values, - int num_rows, - rmm::cuda_stream_view stream); - -void launch_compute_enum_string_lengths(int32_t const* values, - bool const* valid, - int32_t const* valid_enum_values, - int32_t const* enum_name_offsets, - int num_valid_values, - 
int32_t* lengths, - int num_rows, - rmm::cuda_stream_view stream); - -void launch_copy_enum_string_chars(int32_t const* values, - bool const* valid, - int32_t const* valid_enum_values, - int32_t const* enum_name_offsets, - uint8_t const* enum_name_chars, - int num_valid_values, - int32_t const* output_offsets, - char* out_chars, - int num_rows, - rmm::cuda_stream_view stream); - -void maybe_check_required_fields(field_location const* locations, - std::vector const& field_indices, - std::vector const& schema, - int num_rows, - cudf::bitmask_type const* input_null_mask, - cudf::size_type input_offset, - field_location const* parent_locs, - bool* row_force_null, - int32_t const* top_row_indices, - int* error_flag, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr); - -void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_invalid, - rmm::device_uvector& row_invalid, - int num_items, - int32_t const* top_row_indices, - bool propagate_to_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr); - -void validate_enum_and_propagate_rows(rmm::device_uvector const& values, - rmm::device_uvector& valid, - cudf::detail::host_vector const& valid_enums, - rmm::device_uvector& row_invalid, - int num_items, - int32_t const* top_row_indices, - bool propagate_to_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr); - -// ============================================================================ -// Forward declarations of builder/utility functions -// ============================================================================ - -std::unique_ptr make_null_column(cudf::data_type dtype, - cudf::size_type num_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr); - -std::unique_ptr make_null_list_column_with_child( - std::unique_ptr child_col, - cudf::size_type num_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr); - -std::unique_ptr build_enum_string_column( - 
rmm::device_uvector& enum_values, - rmm::device_uvector& valid, - cudf::detail::host_vector const& valid_enums, - std::vector> const& enum_name_bytes, - rmm::device_uvector& d_row_force_null, - int num_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr, - int32_t const* top_row_indices = nullptr, - bool propagate_invalid_rows = true); - -// Complex builder forward declarations -std::unique_ptr build_repeated_enum_string_column( - cudf::column_view const& binary_input, - uint8_t const* message_data, - cudf::size_type const* list_offsets, - cudf::size_type base_offset, - rmm::device_uvector const& d_field_counts, - rmm::device_uvector& d_occurrences, - int total_count, - int num_rows, - cudf::detail::host_vector const& valid_enums, - std::vector> const& enum_name_bytes, - rmm::device_uvector& d_row_force_null, - rmm::device_uvector& d_error, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr); - -std::unique_ptr build_repeated_string_column( - cudf::column_view const& binary_input, - uint8_t const* message_data, - cudf::size_type const* list_offsets, - cudf::size_type base_offset, - device_nested_field_descriptor const& field_desc, - rmm::device_uvector const& d_field_counts, - rmm::device_uvector& d_occurrences, - int total_count, - int num_rows, - bool is_bytes, - rmm::device_uvector& d_error, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr); - -std::unique_ptr build_nested_struct_column( - uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* list_offsets, - cudf::size_type base_offset, - rmm::device_uvector const& d_parent_locs, - std::vector const& child_field_indices, - std::vector const& schema, - int num_fields, - std::vector const& default_ints, - std::vector const& default_floats, - std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, - rmm::device_uvector& 
d_row_force_null, - rmm::device_uvector& d_error, - int num_rows, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr, - int32_t const* top_row_indices, - int depth, - bool propagate_invalid_rows = true); - -std::unique_ptr build_repeated_child_list_column( - uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* row_offsets, - cudf::size_type base_offset, - field_location const* parent_locs, - int num_parent_rows, - int child_schema_idx, - std::vector const& schema, - int num_fields, - std::vector const& default_ints, - std::vector const& default_floats, - std::vector const& default_bools, - std::vector> const& default_strings, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, - rmm::device_uvector& d_row_force_null, - rmm::device_uvector& d_error, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr, - int32_t const* top_row_indices, - int depth, - bool propagate_invalid_rows = true); - -std::unique_ptr build_repeated_struct_column( - cudf::column_view const& binary_input, - uint8_t const* message_data, - cudf::size_type message_data_size, - cudf::size_type const* list_offsets, - cudf::size_type base_offset, - device_nested_field_descriptor const& field_desc, - rmm::device_uvector const& d_field_counts, - rmm::device_uvector& d_occurrences, - int total_count, - int num_rows, - std::vector const& h_device_schema, - std::vector const& child_field_indices, - std::vector const& default_ints, - std::vector const& default_floats, - std::vector const& default_bools, - std::vector> const& default_strings, - std::vector const& schema, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, - rmm::device_uvector& d_row_force_null, - rmm::device_uvector& d_error_top, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr); - -template -inline std::unique_ptr extract_and_build_string_or_bytes_column( - bool as_bytes, - uint8_t const* message_data, - 
int num_rows, - LengthProvider const& length_provider, - CopyProvider const& copy_provider, - ValidityFn validity_fn, - bool has_default, - cudf::detail::host_vector const& default_bytes, - rmm::device_uvector& d_error, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - int32_t def_len = has_default ? static_cast(default_bytes.size()) : 0; - rmm::device_uvector d_default(0, stream, mr); - if (has_default && def_len > 0) { - d_default = cudf::detail::make_device_uvector_async( - default_bytes, stream, rmm::mr::get_current_device_resource()); - } - - rmm::device_uvector lengths(num_rows, stream, mr); - auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((num_rows + threads - 1u) / threads); - extract_lengths_kernel<<>>( - length_provider, num_rows, lengths.data(), has_default, def_len); - - auto [offsets_col, total_size] = - cudf::strings::detail::make_offsets_child_column(lengths.begin(), lengths.end(), stream, mr); - - rmm::device_uvector chars(total_size, stream, mr); - if (total_size > 0) { - copy_varlen_data_kernel - <<>>(message_data, - copy_provider, - num_rows, - offsets_col->view().data(), - chars.data(), - d_error.data(), - has_default, - d_default.data(), - def_len); - } - - rmm::device_uvector valid((num_rows > 0 ? 
num_rows : 1), stream, mr); - thrust::transform(rmm::exec_policy_nosync(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - valid.data(), - validity_fn); - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - if (as_bytes) { - auto bytes_child = - std::make_unique(cudf::data_type{cudf::type_id::UINT8}, - total_size, - rmm::device_buffer(chars.data(), total_size, stream, mr), - rmm::device_buffer{}, - 0); - return cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(bytes_child), null_count, std::move(mask)); - } - - return cudf::make_strings_column( - num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); -} - -template -inline std::unique_ptr extract_typed_column( - cudf::data_type dt, - int encoding, - uint8_t const* message_data, - LocationProvider const& loc_provider, - int num_items, - int blocks, - int threads_per_block, - bool has_default, - int64_t default_int, - double default_float, - bool default_bool, - cudf::detail::host_vector const& default_string, - int schema_idx, - std::vector> const& enum_valid_values, - std::vector>> const& enum_names, - rmm::device_uvector& d_row_force_null, - rmm::device_uvector& d_error, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr, - int32_t const* top_row_indices = nullptr, - bool propagate_invalid_rows = true) -{ - switch (dt.id()) { - case cudf::type_id::BOOL8: { - int64_t def_val = has_default ? (default_bool ? 1 : 0) : 0; - return extract_and_build_scalar_column( - dt, - num_items, - [&](uint8_t* out_ptr, bool* valid_ptr) { - extract_varint_kernel - <<>>(message_data, - loc_provider, - num_items, - out_ptr, - valid_ptr, - d_error.data(), - has_default, - def_val); - }, - stream, - mr); - } - case cudf::type_id::INT32: { - rmm::device_uvector out(num_items, stream, mr); - rmm::device_uvector valid((num_items > 0 ? 
num_items : 1), stream, mr); - extract_integer_into_buffers(message_data, - loc_provider, - num_items, - blocks, - threads_per_block, - has_default, - default_int, - encoding, - true, - out.data(), - valid.data(), - d_error.data(), - stream); - if (schema_idx < static_cast(enum_valid_values.size())) { - auto const& valid_enums = enum_valid_values[schema_idx]; - if (!valid_enums.empty()) { - validate_enum_and_propagate_rows(out, - valid, - valid_enums, - d_row_force_null, - num_items, - top_row_indices, - propagate_invalid_rows, - stream, - mr); - } - } - auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); - return std::make_unique( - dt, num_items, out.release(), std::move(mask), null_count); - } - case cudf::type_id::UINT32: - return extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_items, - blocks, - threads_per_block, - d_error, - has_default, - default_int, - encoding, - false, - stream, - mr); - case cudf::type_id::INT64: - return extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_items, - blocks, - threads_per_block, - d_error, - has_default, - default_int, - encoding, - true, - stream, - mr); - case cudf::type_id::UINT64: - return extract_and_build_integer_column(dt, - message_data, - loc_provider, - num_items, - blocks, - threads_per_block, - d_error, - has_default, - default_int, - encoding, - false, - stream, - mr); - case cudf::type_id::FLOAT32: { - float def_float_val = has_default ? static_cast(default_float) : 0.0f; - return extract_and_build_scalar_column( - dt, - num_items, - [&](float* out_ptr, bool* valid_ptr) { - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_items, - out_ptr, - valid_ptr, - d_error.data(), - has_default, - def_float_val); - }, - stream, - mr); - } - case cudf::type_id::FLOAT64: { - double def_double = has_default ? 
default_float : 0.0; - return extract_and_build_scalar_column( - dt, - num_items, - [&](double* out_ptr, bool* valid_ptr) { - extract_fixed_kernel - <<>>(message_data, - loc_provider, - num_items, - out_ptr, - valid_ptr, - d_error.data(), - has_default, - def_double); - }, - stream, - mr); - } - default: return make_null_column(dt, num_items, stream, mr); - } -} - -template -inline std::unique_ptr build_repeated_scalar_column( - cudf::column_view const& binary_input, - uint8_t const* message_data, - cudf::size_type const* list_offsets, - cudf::size_type base_offset, - device_nested_field_descriptor const& field_desc, - rmm::device_uvector const& d_field_counts, - rmm::device_uvector& d_occurrences, - int total_count, - int num_rows, - rmm::device_uvector& d_error, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref mr) -{ - auto const input_null_count = binary_input.null_count(); - auto const field_type_id = static_cast(field_desc.output_type_id); - - if (total_count == 0) { - // All rows have count=0, but we still need to check input nulls - rmm::device_uvector offsets(num_rows + 1, stream, mr); - thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - offsets.release(), - rmm::device_buffer{}, - 0); - auto elem_type = field_type_id == cudf::type_id::LIST ? 
cudf::type_id::UINT8 : field_type_id; - auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); - - if (input_null_count > 0) { - // Copy input null mask - only input nulls produce output nulls - auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - input_null_count, - std::move(null_mask)); - } else { - // No input nulls, all rows get empty arrays [] - return cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); - } - } - - rmm::device_uvector list_offs(num_rows + 1, stream, mr); - thrust::exclusive_scan(rmm::exec_policy_nosync(stream), - d_field_counts.begin(), - d_field_counts.end(), - list_offs.begin(), - 0); - - int32_t total_count_i32 = static_cast(total_count); - thrust::fill_n(rmm::exec_policy_nosync(stream), list_offs.data() + num_rows, 1, total_count_i32); - - rmm::device_uvector values(total_count, stream, mr); - - auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((total_count + threads - 1u) / threads); - - int encoding = field_desc.encoding; - bool zigzag = (encoding == encoding_value(proto_encoding::ZIGZAG)); - - // For float/double types, always use fixed kernel (they use wire type 32BIT/64BIT) - // For integer types, use fixed kernel only if encoding is - // encoding_value(proto_encoding::FIXED) - constexpr bool is_floating_point = std::is_same_v || std::is_same_v; - bool use_fixed_kernel = is_floating_point || (encoding == encoding_value(proto_encoding::FIXED)); - - repeated_location_provider loc_provider{list_offsets, base_offset, d_occurrences.data()}; - if (use_fixed_kernel) { - if constexpr (sizeof(T) == 4) { - extract_fixed_kernel - <<>>( - message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); - } else { - extract_fixed_kernel - <<>>( - message_data, loc_provider, total_count, values.data(), nullptr, 
d_error.data()); - } - } else if (zigzag) { - extract_varint_kernel - <<>>( - message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); - } else { - extract_varint_kernel - <<>>( - message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); - } - - auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, - num_rows + 1, - list_offs.release(), - rmm::device_buffer{}, - 0); - auto child_col = std::make_unique( - cudf::data_type{field_type_id}, total_count, values.release(), rmm::device_buffer{}, 0); - - // Only rows where INPUT is null should produce null output - // Rows with valid input but count=0 should produce empty array [] - if (input_null_count > 0) { - // Copy input null mask - only input nulls produce output nulls - auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_lists_column(num_rows, - std::move(offsets_col), - std::move(child_col), - input_null_count, - std::move(null_mask)); - } - - return cudf::make_lists_column( - num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); -} - -} // namespace spark_rapids_jni::protobuf::detail diff --git a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp new file mode 100644 index 0000000000..21de289158 --- /dev/null +++ b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp @@ -0,0 +1,310 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "protobuf/protobuf_types.cuh" + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include + +namespace spark_rapids_jni::protobuf::detail { + +// ============================================================================ +// Field number lookup table helpers +// ============================================================================ + +/** + * Build a host-side direct-mapped lookup table: field_number -> index. + * Returns an empty vector if the max field number exceeds the threshold. + * + * @tparam GetFieldNumber callable (int i) -> int returning the field number for index i + */ +template +inline std::vector build_lookup_table(int num_entries, GetFieldNumber get_fn) +{ + int max_fn = 0; + for (int i = 0; i < num_entries; i++) { + max_fn = std::max(max_fn, get_fn(i)); + } + if (max_fn > FIELD_LOOKUP_TABLE_MAX) { return {}; } + std::vector table(max_fn + 1, -1); + for (int i = 0; i < num_entries; i++) { + table[get_fn(i)] = i; + } + return table; +} + +inline std::vector build_index_lookup_table(nested_field_descriptor const* schema, + int const* field_indices, + int num_indices) +{ + return build_lookup_table(num_indices, + [&](int i) { return schema[field_indices[i]].field_number; }); +} + +inline std::vector build_field_lookup_table(field_descriptor const* descs, int num_fields) +{ + return build_lookup_table(num_fields, [&](int i) { return descs[i].field_number; }); +} + +/** + * Find all child field indices for a given parent index in the schema. + * This is a commonly used pattern throughout the codebase. 
+ * + * @param schema The schema vector (either nested_field_descriptor or + * device_nested_field_descriptor) + * @param num_fields Number of fields in the schema + * @param parent_idx The parent index to search for + * @return Vector of child field indices + */ +template +std::vector find_child_field_indices(SchemaT const& schema, int num_fields, int parent_idx) +{ + std::vector child_indices; + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == parent_idx) { child_indices.push_back(i); } + } + return child_indices; +} + +// Forward declarations needed by make_empty_struct_column_with_schema +std::unique_ptr make_empty_column_safe(cudf::data_type dtype, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr make_empty_list_column(std::unique_ptr element_col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +/** + * Extract output type from either nested_field_descriptor (.output_type is cudf::type_id) + * or device_nested_field_descriptor (.output_type_id is int). 
+ */ +template +inline cudf::type_id get_output_type_id(FieldT const& field) +{ + if constexpr (std::is_same_v) { + return static_cast(field.output_type_id); + } else { + return field.output_type; + } +} + +template +std::unique_ptr make_empty_struct_column_with_schema( + SchemaT const& schema, + int parent_idx, + int num_fields, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto child_indices = find_child_field_indices(schema, num_fields, parent_idx); + + std::vector> children; + for (int child_idx : child_indices) { + auto child_type = cudf::data_type{get_output_type_id(schema[child_idx])}; + + std::unique_ptr child_col; + if (child_type.id() == cudf::type_id::STRUCT) { + child_col = make_empty_struct_column_with_schema(schema, child_idx, num_fields, stream, mr); + } else { + child_col = make_empty_column_safe(child_type, stream, mr); + } + + if (schema[child_idx].is_repeated) { + child_col = make_empty_list_column(std::move(child_col), stream, mr); + } + + children.push_back(std::move(child_col)); + } + + return cudf::make_structs_column(0, std::move(children), 0, rmm::device_buffer{}, stream, mr); +} + +void maybe_check_required_fields(field_location const* locations, + std::vector const& field_indices, + std::vector const& schema, + int num_rows, + cudf::bitmask_type const* input_null_mask, + cudf::size_type input_offset, + field_location const* parent_locs, + bool* row_force_null, + int32_t const* top_row_indices, + int* error_flag, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_invalid, + rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +void validate_enum_and_propagate_rows(rmm::device_uvector const& values, + rmm::device_uvector& valid, + cudf::detail::host_vector const& valid_enums, + 
rmm::device_uvector& row_invalid, + int num_items, + int32_t const* top_row_indices, + bool propagate_to_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +// ============================================================================ +// Forward declarations of builder/utility functions +// ============================================================================ + +std::unique_ptr make_null_column(cudf::data_type dtype, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr make_null_list_column_with_child( + std::unique_ptr child_col, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr build_enum_string_column( + rmm::device_uvector& enum_values, + rmm::device_uvector& valid, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_force_null, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices = nullptr, + bool propagate_invalid_rows = true); + +// Complex builder forward declarations +std::unique_ptr build_repeated_enum_string_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr build_repeated_string_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& 
d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + bool is_bytes, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +std::unique_ptr build_nested_struct_column( + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_parent_locs, + std::vector const& child_field_indices, + std::vector const& schema, + int num_fields, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, + int depth, + bool propagate_invalid_rows = true); + +std::unique_ptr build_repeated_child_list_column( + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_parent_rows, + int child_schema_idx, + std::vector const& schema, + int num_fields, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, + int depth, + bool propagate_invalid_rows = true); + +std::unique_ptr build_repeated_struct_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + 
device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + std::vector const& h_device_schema, + std::vector const& child_field_indices, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector const& schema, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error_top, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +} // namespace spark_rapids_jni::protobuf::detail diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cuh b/src/main/cpp/src/protobuf/protobuf_kernels.cuh index 0804ac4e59..2bdb7a76da 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cuh +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cuh @@ -17,11 +17,32 @@ #pragma once #include "protobuf/protobuf_device_helpers.cuh" +#include "protobuf/protobuf_host_helpers.hpp" #include +#include +#include +#include #include +#include +#include +#include + +#include +#include +#include #include +#include +#include +#include +#include + +#include +#include +#include +#include +#include namespace spark_rapids_jni::protobuf::detail { @@ -409,4 +430,648 @@ CUDF_KERNEL void copy_varlen_data_kernel(uint8_t const* message_data, } // ============================================================================ +// Host wrapper declarations for kernel launches +// ============================================================================ + +void launch_scan_all_fields(cudf::column_device_view const& d_in, + field_descriptor const* field_descs, + int num_fields, + int const* field_lookup, + int field_lookup_size, + field_location* locations, + int* error_flag, + bool* row_has_invalid_data, + int num_rows, + rmm::cuda_stream_view stream); + +void 
launch_count_repeated_fields(cudf::column_device_view const& d_in, + device_nested_field_descriptor const* schema, + int num_fields, + int depth_level, + repeated_field_info* repeated_info, + int num_repeated_fields, + int const* repeated_field_indices, + field_location* nested_locations, + int num_nested_fields, + int const* nested_field_indices, + int* error_flag, + int const* fn_to_rep_idx, + int fn_to_rep_size, + int const* fn_to_nested_idx, + int fn_to_nested_size, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_scan_all_repeated_occurrences(cudf::column_device_view const& d_in, + repeated_field_scan_desc const* scan_descs, + int num_scan_fields, + int* error_flag, + int const* fn_to_desc_idx, + int fn_to_desc_size, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_extract_strided_locations(field_location const* nested_locations, + int field_idx, + int num_fields, + field_location* parent_locs, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_scan_nested_message_fields(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + int num_parent_rows, + field_descriptor const* field_descs, + int num_fields, + field_location* output_locations, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_scan_repeated_message_children(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* msg_row_offsets, + field_location const* msg_locs, + int num_occurrences, + field_descriptor const* child_descs, + int num_child_fields, + field_location* child_locs, + int* error_flag, + int const* child_lookup, + int child_lookup_size, + rmm::cuda_stream_view stream); + +void launch_count_repeated_in_nested(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* 
parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + repeated_field_info* repeated_info, + int num_repeated, + int const* repeated_indices, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_scan_repeated_in_nested(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int32_t const* occ_prefix_sums, + int const* repeated_indices, + repeated_occurrence* occurrences, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_compute_nested_struct_locations(field_location const* child_locs, + field_location const* msg_locs, + cudf::size_type const* msg_row_offsets, + int child_idx, + int num_child_fields, + field_location* nested_locs, + cudf::size_type* nested_row_offsets, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_compute_grandchild_parent_locations(field_location const* parent_locs, + field_location const* child_locs, + int child_idx, + int num_child_fields, + field_location* gc_parent_abs, + int num_rows, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_compute_virtual_parents_for_nested_repeated(repeated_occurrence const* occurrences, + cudf::size_type const* row_list_offsets, + field_location const* parent_locations, + cudf::size_type* virtual_row_offsets, + field_location* virtual_parent_locs, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_compute_msg_locations_from_occurrences(repeated_occurrence const* occurrences, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + field_location* msg_locs, + cudf::size_type* msg_row_offsets, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_validate_enum_values(int32_t const* values, + bool* valid, + bool* 
row_has_invalid_enum, + int32_t const* valid_enum_values, + int num_valid_values, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_compute_enum_string_lengths(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + int num_valid_values, + int32_t* lengths, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_copy_enum_string_chars(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + uint8_t const* enum_name_chars, + int num_valid_values, + int32_t const* output_offsets, + char* out_chars, + int num_rows, + rmm::cuda_stream_view stream); + +// ============================================================================ +// Host-side template helpers that launch CUDA kernels +// ============================================================================ + +template +inline std::pair make_null_mask_from_valid( + rmm::device_uvector const& valid, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto begin = thrust::make_counting_iterator(0); + auto end = begin + valid.size(); + auto pred = [ptr = valid.data()] __device__(cudf::size_type i) { + return static_cast(ptr[i]); + }; + return cudf::detail::valid_if(begin, end, pred, stream, mr); +} + +template +std::unique_ptr extract_and_build_scalar_column(cudf::data_type dt, + int num_rows, + LaunchFn&& launch_extract, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? 
num_rows : 1), stream, mr); + launch_extract(out.data(), valid.data()); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + return std::make_unique(dt, num_rows, out.release(), std::move(mask), null_count); +} + +template +inline void extract_integer_into_buffers(uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_rows, + int blocks, + int threads, + bool has_default, + int64_t default_value, + int encoding, + bool enable_zigzag, + T* out_ptr, + bool* valid_ptr, + int* error_ptr, + rmm::cuda_stream_view stream) +{ + if (enable_zigzag && encoding == encoding_value(proto_encoding::ZIGZAG)) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + default_value); + } else if (encoding == encoding_value(proto_encoding::FIXED)) { + if constexpr (sizeof(T) == 4) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + static_cast(default_value)); + } else { + static_assert(sizeof(T) == 8, "extract_integer_into_buffers only supports 32/64-bit"); + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + static_cast(default_value)); + } + } else { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out_ptr, + valid_ptr, + error_ptr, + has_default, + default_value); + } +} + +template +std::unique_ptr extract_and_build_integer_column(cudf::data_type dt, + uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_rows, + int blocks, + int threads, + rmm::device_uvector& d_error, + bool has_default, + int64_t default_value, + int encoding, + bool enable_zigzag, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + return extract_and_build_scalar_column( + dt, + num_rows, + [&](T* out_ptr, bool* valid_ptr) { + extract_integer_into_buffers(message_data, + 
loc_provider, + num_rows, + blocks, + threads, + has_default, + default_value, + encoding, + enable_zigzag, + out_ptr, + valid_ptr, + d_error.data(), + stream); + }, + stream, + mr); +} + +struct extract_strided_count { + repeated_field_info const* info; + int field_idx; + int num_fields; + + __device__ int32_t operator()(int row) const + { + return info[flat_index(static_cast(row), + static_cast(num_fields), + static_cast(field_idx))] + .count; + } +}; + +template +inline std::unique_ptr extract_and_build_string_or_bytes_column( + bool as_bytes, + uint8_t const* message_data, + int num_rows, + LengthProvider const& length_provider, + CopyProvider const& copy_provider, + ValidityFn validity_fn, + bool has_default, + cudf::detail::host_vector const& default_bytes, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + int32_t def_len = has_default ? static_cast(default_bytes.size()) : 0; + rmm::device_uvector d_default(0, stream, mr); + if (has_default && def_len > 0) { + d_default = cudf::detail::make_device_uvector_async( + default_bytes, stream, rmm::mr::get_current_device_resource()); + } + + rmm::device_uvector lengths(num_rows, stream, mr); + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((num_rows + threads - 1u) / threads); + extract_lengths_kernel<<>>( + length_provider, num_rows, lengths.data(), has_default, def_len); + + auto [offsets_col, total_size] = + cudf::strings::detail::make_offsets_child_column(lengths.begin(), lengths.end(), stream, mr); + + rmm::device_uvector chars(total_size, stream, mr); + if (total_size > 0) { + copy_varlen_data_kernel + <<>>(message_data, + copy_provider, + num_rows, + offsets_col->view().data(), + chars.data(), + d_error.data(), + has_default, + d_default.data(), + def_len); + } + + rmm::device_uvector valid((num_rows > 0 ? 
num_rows : 1), stream, mr); + thrust::transform(rmm::exec_policy_nosync(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + valid.data(), + validity_fn); + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + if (as_bytes) { + auto bytes_child = + std::make_unique(cudf::data_type{cudf::type_id::UINT8}, + total_size, + rmm::device_buffer(chars.data(), total_size, stream, mr), + rmm::device_buffer{}, + 0); + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(bytes_child), null_count, std::move(mask)); + } + + return cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); +} + +template +inline std::unique_ptr extract_typed_column( + cudf::data_type dt, + int encoding, + uint8_t const* message_data, + LocationProvider const& loc_provider, + int num_items, + int blocks, + int threads_per_block, + bool has_default, + int64_t default_int, + double default_float, + bool default_bool, + cudf::detail::host_vector const& default_string, + int schema_idx, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices = nullptr, + bool propagate_invalid_rows = true) +{ + switch (dt.id()) { + case cudf::type_id::BOOL8: { + int64_t def_val = has_default ? (default_bool ? 1 : 0) : 0; + return extract_and_build_scalar_column( + dt, + num_items, + [&](uint8_t* out_ptr, bool* valid_ptr) { + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_val); + }, + stream, + mr); + } + case cudf::type_id::INT32: { + rmm::device_uvector out(num_items, stream, mr); + rmm::device_uvector valid((num_items > 0 ? 
num_items : 1), stream, mr); + extract_integer_into_buffers(message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + has_default, + default_int, + encoding, + true, + out.data(), + valid.data(), + d_error.data(), + stream); + if (schema_idx < static_cast(enum_valid_values.size())) { + auto const& valid_enums = enum_valid_values[schema_idx]; + if (!valid_enums.empty()) { + validate_enum_and_propagate_rows(out, + valid, + valid_enums, + d_row_force_null, + num_items, + top_row_indices, + propagate_invalid_rows, + stream, + mr); + } + } + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + return std::make_unique( + dt, num_items, out.release(), std::move(mask), null_count); + } + case cudf::type_id::UINT32: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + false, + stream, + mr); + case cudf::type_id::INT64: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + true, + stream, + mr); + case cudf::type_id::UINT64: + return extract_and_build_integer_column(dt, + message_data, + loc_provider, + num_items, + blocks, + threads_per_block, + d_error, + has_default, + default_int, + encoding, + false, + stream, + mr); + case cudf::type_id::FLOAT32: { + float def_float_val = has_default ? static_cast(default_float) : 0.0f; + return extract_and_build_scalar_column( + dt, + num_items, + [&](float* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_float_val); + }, + stream, + mr); + } + case cudf::type_id::FLOAT64: { + double def_double = has_default ? 
default_float : 0.0; + return extract_and_build_scalar_column( + dt, + num_items, + [&](double* out_ptr, bool* valid_ptr) { + extract_fixed_kernel + <<>>(message_data, + loc_provider, + num_items, + out_ptr, + valid_ptr, + d_error.data(), + has_default, + def_double); + }, + stream, + mr); + } + default: return make_null_column(dt, num_items, stream, mr); + } +} + +template +inline std::unique_ptr build_repeated_scalar_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const input_null_count = binary_input.null_count(); + auto const field_type_id = static_cast(field_desc.output_type_id); + + if (total_count == 0) { + // All rows have count=0, but we still need to check input nulls + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + auto elem_type = field_type_id == cudf::type_id::LIST ? 
cudf::type_id::UINT8 : field_type_id; + auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); + + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask)); + } else { + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); + } + } + + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy_nosync(stream), + d_field_counts.begin(), + d_field_counts.end(), + list_offs.begin(), + 0); + + int32_t total_count_i32 = static_cast(total_count); + thrust::fill_n(rmm::exec_policy_nosync(stream), list_offs.data() + num_rows, 1, total_count_i32); + + rmm::device_uvector values(total_count, stream, mr); + + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((total_count + threads - 1u) / threads); + + int encoding = field_desc.encoding; + bool zigzag = (encoding == encoding_value(proto_encoding::ZIGZAG)); + + constexpr bool is_floating_point = std::is_same_v || std::is_same_v; + bool use_fixed_kernel = is_floating_point || (encoding == encoding_value(proto_encoding::FIXED)); + + repeated_location_provider loc_provider{list_offsets, base_offset, d_occurrences.data()}; + if (use_fixed_kernel) { + if constexpr (sizeof(T) == 4) { + extract_fixed_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } else { + extract_fixed_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } + } else if (zigzag) { + extract_varint_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } else { + extract_varint_kernel + <<>>( + message_data, loc_provider, total_count, values.data(), nullptr, d_error.data()); + } + + auto offsets_col = 
std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + auto child_col = std::make_unique( + cudf::data_type{field_type_id}, total_count, values.release(), rmm::device_buffer{}, 0); + + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask)); + } + + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); +} + } // namespace spark_rapids_jni::protobuf::detail From 8ea16b5c1224c858115a68c8e027ee650172c942 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 7 Apr 2026 10:57:55 +0800 Subject: [PATCH 103/107] Remove stale mr parameter from call sites maybe_check_required_fields, propagate_invalid_enum_flags_to_rows, and validate_enum_and_propagate_rows no longer take a mr parameter (changed in part0 review). Update all callers. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/main/cpp/src/protobuf/protobuf.cu | 6 ++---- src/main/cpp/src/protobuf/protobuf_builders.cu | 18 ++++++------------ 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/main/cpp/src/protobuf/protobuf.cu b/src/main/cpp/src/protobuf/protobuf.cu index 5fa017285f..50f19303c1 100644 --- a/src/main/cpp/src/protobuf/protobuf.cu +++ b/src/main/cpp/src/protobuf/protobuf.cu @@ -599,8 +599,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& track_permissive_null_rows ? d_row_force_null.data() : nullptr, nullptr, d_error.data(), - stream, - mr); + stream); // Batched scalar extraction: group non-special fixed-width fields by extraction // category and extract all fields of each category with a single 2D kernel launch. @@ -923,8 +922,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& track_permissive_null_rows ? 
d_row_force_null.data() : nullptr, nullptr, d_error.data(), - stream, - mr); + stream); // Process repeated fields (three-phase: offsets → combined scan → build columns) if (num_repeated > 0) { diff --git a/src/main/cpp/src/protobuf/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu index 86bd2a6232..1734a41704 100644 --- a/src/main/cpp/src/protobuf/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -329,8 +329,7 @@ std::unique_ptr build_enum_string_column( num_rows, top_row_indices, propagate_invalid_rows, - stream, - mr); + stream); return build_enum_string_values_column(enum_values, valid, lookup, num_rows, stream, mr); } @@ -390,8 +389,7 @@ std::unique_ptr build_repeated_msg_child_enum_string_column( total_count, top_row_indices, propagate_invalid_rows, - stream, - mr); + stream); return build_enum_string_values_column(enum_values, valid, lookup, total_count, stream, mr); } @@ -456,8 +454,7 @@ std::unique_ptr build_repeated_enum_string_column( total_count, d_top_row_indices.data(), true, - stream, - mr); + stream); auto child_col = build_enum_string_values_column(enum_ints, elem_valid, lookup, total_count, stream, mr); @@ -813,8 +810,7 @@ std::unique_ptr build_repeated_struct_column( d_row_force_null.size() > 0 ? d_row_force_null.data() : nullptr, d_top_row_indices.data(), d_error.data(), - stream, - mr); + stream); // Note: We no longer need to copy child_locs to host because: // 1. All scalar extraction kernels access d_child_locs directly on device @@ -1127,8 +1123,7 @@ std::unique_ptr build_nested_struct_column( d_row_force_null.size() > 0 ? 
d_row_force_null.data() : nullptr, top_row_indices, d_error.data(), - stream, - mr); + stream); std::vector> struct_children; for (int ci = 0; ci < num_child_fields; ci++) { @@ -1583,8 +1578,7 @@ std::unique_ptr build_repeated_child_list_column( total_rep_count, d_rep_top_row_indices.data(), propagate_invalid_rows, - stream, - mr); + stream); child_values = build_enum_string_values_column(enum_values, valid, lookup, total_rep_count, stream, mr); } else { From 2e91499b520fa1cdb8b1426bde489a6abd693718 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 7 Apr 2026 11:21:39 +0800 Subject: [PATCH 104/107] Apply part0 brace style to dev-only code Remove braces from single-statement if bodies to match the convention established during part0 review. Co-Authored-By: Claude Opus 4.6 (1M context) --- src/main/cpp/src/protobuf/protobuf.cu | 6 +- src/main/cpp/src/protobuf/protobuf_kernels.cu | 110 +++++++++--------- 2 files changed, 58 insertions(+), 58 deletions(-) diff --git a/src/main/cpp/src/protobuf/protobuf.cu b/src/main/cpp/src/protobuf/protobuf.cu index 50f19303c1..ab34ceccf2 100644 --- a/src/main/cpp/src/protobuf/protobuf.cu +++ b/src/main/cpp/src/protobuf/protobuf.cu @@ -627,7 +627,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& bool zz = (enc == proto_encoding::ZIGZAG); // STRING, LIST, and enum-as-string go to per-field path - if (tid == cudf::type_id::STRING || tid == cudf::type_id::LIST) { continue; } + if (tid == cudf::type_id::STRING || tid == cudf::type_id::LIST) continue; bool is_fixed = (enc == proto_encoding::FIXED); @@ -672,7 +672,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& // Helper: batch-extract one group using a 2D kernel, then build columns. 
auto do_batch = [&](std::vector const& idxs, auto kernel_launcher) { int nf = static_cast(idxs.size()); - if (nf == 0) { return; } + if (nf == 0) return; std::vector> bufs; bufs.reserve(nf); @@ -1348,7 +1348,7 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& if (h_error == ERR_SCHEMA_TOO_LARGE || h_error == ERR_REPEATED_COUNT_MISMATCH) { throw cudf::logic_error(error_message(h_error)); } - if (fail_on_errors && h_error != 0) { throw cudf::logic_error(error_message(h_error)); } + if (fail_on_errors && h_error != 0) throw cudf::logic_error(error_message(h_error)); // Build final struct with PERMISSIVE mode null mask for invalid enums cudf::size_type struct_null_count = 0; diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cu b/src/main/cpp/src/protobuf/protobuf_kernels.cu index cbf34199ee..2b771302b3 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cu @@ -372,7 +372,7 @@ CUDF_KERNEL void count_repeated_fields_kernel(cudf::column_device_view const d_i { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; - if (row >= in.size()) { return; } + if (row >= in.size()) return; // Initialize repeated counts to 0 for (int f = 0; f < num_repeated_fields; f++) { @@ -388,21 +388,21 @@ CUDF_KERNEL void count_repeated_fields_kernel(cudf::column_device_view const d_i -1, 0}; } - if (in.nullable() && in.is_null(row)) { return; } + if (in.nullable() && in.is_null(row)) return; auto const base = in.offset_at(0); auto const child = in.get_sliced_child(); auto const* bytes = reinterpret_cast(child.data()); int32_t start = in.offset_at(row) - base; int32_t end = in.offset_at(row + 1) - base; - if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } + if (!check_message_bounds(start, end, child.size(), error_flag)) return; uint8_t const* cur = bytes + start; uint8_t const* msg_end = bytes + end; while (cur < msg_end) { 
proto_tag tag; - if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + if (!decode_tag(cur, msg_end, tag, error_flag)) return; int fn = tag.field_number; int wt = tag.wire_type; @@ -481,14 +481,14 @@ CUDF_KERNEL void count_repeated_fields_kernel(cudf::column_device_view const d_i if (i >= 0) { int schema_idx = nested_field_indices[i]; if (schema[schema_idx].depth == depth_level) { - if (!handle_nested(i)) { return; } + if (!handle_nested(i)) return; } } } else { for (int i = 0; i < num_nested_fields; i++) { int schema_idx = nested_field_indices[i]; if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { - if (!handle_nested(i)) { return; } + if (!handle_nested(i)) return; } } } @@ -516,16 +516,16 @@ CUDF_KERNEL void scan_all_repeated_occurrences_kernel(cudf::column_device_view c { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); cudf::detail::lists_column_device_view in{d_in}; - if (row >= in.size()) { return; } + if (row >= in.size()) return; - if (in.nullable() && in.is_null(row)) { return; } + if (in.nullable() && in.is_null(row)) return; auto const base = in.offset_at(0); auto const child = in.get_sliced_child(); auto const* bytes = reinterpret_cast(child.data()); int32_t start = in.offset_at(row) - base; int32_t end = in.offset_at(row + 1) - base; - if (!check_message_bounds(start, end, child.size(), error_flag)) { return; } + if (!check_message_bounds(start, end, child.size(), error_flag)) return; uint8_t const* cur = bytes + start; uint8_t const* msg_end = bytes + end; @@ -542,7 +542,7 @@ CUDF_KERNEL void scan_all_repeated_occurrences_kernel(cudf::column_device_view c while (cur < msg_end) { proto_tag tag; - if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + if (!decode_tag(cur, msg_end, tag, error_flag)) return; int fn = tag.field_number; int wt = tag.wire_type; @@ -569,12 +569,12 @@ CUDF_KERNEL void scan_all_repeated_occurrences_kernel(cudf::column_device_view c if (fn_to_desc_idx != nullptr 
&& fn > 0 && fn < fn_to_desc_size) { int f = fn_to_desc_idx[fn]; if (f >= 0 && f < num_scan_fields) { - if (!try_scan(f)) { return; } + if (!try_scan(f)) return; } } else { for (int f = 0; f < num_scan_fields; f++) { if (scan_descs[f].field_number == fn) { - if (!try_scan(f)) { return; } + if (!try_scan(f)) return; } } } @@ -620,7 +620,7 @@ CUDF_KERNEL void scan_nested_message_fields_kernel(uint8_t const* message_data, int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_parent_rows) { return; } + if (row >= num_parent_rows) return; for (int f = 0; f < num_fields; f++) { output_locations[flat_index( @@ -628,7 +628,7 @@ CUDF_KERNEL void scan_nested_message_fields_kernel(uint8_t const* message_data, } auto const& parent_loc = parent_locations[row]; - if (parent_loc.offset < 0) { return; } + if (parent_loc.offset < 0) return; auto parent_row_start = parent_row_offsets[row] - parent_base_offset; int64_t nested_start_off = static_cast(parent_row_start) + parent_loc.offset; @@ -644,7 +644,7 @@ CUDF_KERNEL void scan_nested_message_fields_kernel(uint8_t const* message_data, while (cur < nested_end) { proto_tag tag; - if (!decode_tag(cur, nested_end, tag, error_flag)) { return; } + if (!decode_tag(cur, nested_end, tag, error_flag)) return; int fn = tag.field_number; int wt = tag.wire_type; @@ -724,7 +724,7 @@ CUDF_KERNEL void scan_repeated_message_children_kernel( int child_lookup_size) { auto occ_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (occ_idx >= num_occurrences) { return; } + if (occ_idx >= num_occurrences) return; // Initialize child locations to not found for (int f = 0; f < num_child_fields; f++) { @@ -734,7 +734,7 @@ CUDF_KERNEL void scan_repeated_message_children_kernel( } auto const& msg_loc = msg_locs[occ_idx]; - if (msg_loc.offset < 0) { return; } + if (msg_loc.offset < 0) return; cudf::size_type row_offset = msg_row_offsets[occ_idx]; int64_t msg_start_off = static_cast(row_offset) + 
msg_loc.offset; @@ -750,7 +750,7 @@ CUDF_KERNEL void scan_repeated_message_children_kernel( while (cur < msg_end) { proto_tag tag; - if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + if (!decode_tag(cur, msg_end, tag, error_flag)) return; int fn = tag.field_number; int wt = tag.wire_type; @@ -850,7 +850,7 @@ CUDF_KERNEL void count_repeated_in_nested_kernel(uint8_t const* message_data, int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) { return; } + if (row >= num_rows) return; // Initialize counts for (int ri = 0; ri < num_repeated; ri++) { @@ -860,7 +860,7 @@ CUDF_KERNEL void count_repeated_in_nested_kernel(uint8_t const* message_data, } auto const& parent_loc = parent_locs[row]; - if (parent_loc.offset < 0) { return; } + if (parent_loc.offset < 0) return; cudf::size_type row_off; row_off = row_offsets[row] - base_offset; @@ -878,7 +878,7 @@ CUDF_KERNEL void count_repeated_in_nested_kernel(uint8_t const* message_data, while (cur < msg_end) { proto_tag tag; - if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + if (!decode_tag(cur, msg_end, tag, error_flag)) return; int fn = tag.field_number; int wt = tag.wire_type; @@ -929,13 +929,13 @@ CUDF_KERNEL void scan_repeated_in_nested_kernel(uint8_t const* message_data, int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) { return; } + if (row >= num_rows) return; int write_idx = occ_prefix_sums[row]; int write_end = occ_prefix_sums[row + 1]; auto const& parent_loc = parent_locs[row]; if (parent_loc.offset < 0) { - if (write_idx != write_end) { set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); } + if (write_idx != write_end) set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); return; } @@ -956,7 +956,7 @@ CUDF_KERNEL void scan_repeated_in_nested_kernel(uint8_t const* message_data, while (cur < msg_end) { proto_tag tag; - if (!decode_tag(cur, msg_end, tag, error_flag)) { return; } + if 
(!decode_tag(cur, msg_end, tag, error_flag)) return; int fn = tag.field_number; int wt = tag.wire_type; @@ -983,7 +983,7 @@ CUDF_KERNEL void scan_repeated_in_nested_kernel(uint8_t const* message_data, cur = next; } - if (write_idx != write_end) { set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); } + if (write_idx != write_end) set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); } /** @@ -1003,7 +1003,7 @@ CUDF_KERNEL void compute_nested_struct_locations_kernel( int* error_flag) { int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_count) { return; } + if (idx >= total_count) return; nested_locs[idx] = child_locs[flat_index(static_cast(idx), static_cast(num_child_fields), @@ -1034,7 +1034,7 @@ CUDF_KERNEL void compute_grandchild_parent_locations_kernel( int* error_flag) { int row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= num_rows) { return; } + if (row >= num_rows) return; auto const& parent_loc = parent_locs[row]; auto const& child_loc = child_locs[flat_index(static_cast(row), @@ -1072,7 +1072,7 @@ CUDF_KERNEL void compute_virtual_parents_for_nested_repeated_kernel( int* error_flag) { int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_count) { return; } + if (idx >= total_count) return; auto const& occ = occurrences[idx]; auto const& ploc = parent_locations[occ.row_idx]; @@ -1110,7 +1110,7 @@ CUDF_KERNEL void compute_msg_locations_from_occurrences_kernel( int* error_flag) { int idx = blockIdx.x * blockDim.x + threadIdx.x; - if (idx >= total_count) { return; } + if (idx >= total_count) return; auto const& occ = occurrences[idx]; auto row_offset = static_cast(list_offsets[occ.row_idx]) - base_offset; @@ -1136,7 +1136,7 @@ CUDF_KERNEL void extract_strided_locations_kernel(field_location const* nested_l int num_rows) { int row = blockIdx.x * blockDim.x + threadIdx.x; - if (row >= num_rows) { return; } + if (row >= num_rows) return; parent_locs[row] = nested_locations[flat_index( static_cast(row), 
static_cast(num_fields), static_cast(field_idx))]; } @@ -1162,11 +1162,11 @@ CUDF_KERNEL void check_required_fields_kernel( int* error_flag) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) { return; } + if (row >= num_rows) return; if (input_null_mask != nullptr && !cudf::bit_is_set(input_null_mask, row + input_offset)) { return; } - if (parent_locs != nullptr && parent_locs[row].offset < 0) { return; } + if (parent_locs != nullptr && parent_locs[row].offset < 0) return; for (int f = 0; f < num_fields; f++) { if (is_required[f] != 0 && locations[flat_index(static_cast(row), @@ -1208,10 +1208,10 @@ CUDF_KERNEL void validate_enum_values_kernel( int num_rows) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) { return; } + if (row >= num_rows) return; // Skip if already invalid (field was missing) - missing field is not an enum error - if (!valid[row]) { return; } + if (!valid[row]) return; int32_t val = values[row]; @@ -1256,7 +1256,7 @@ CUDF_KERNEL void compute_enum_string_lengths_kernel( int num_rows) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) { return; } + if (row >= num_rows) return; if (!valid[row]) { lengths[row] = 0; @@ -1298,8 +1298,8 @@ CUDF_KERNEL void copy_enum_string_chars_kernel( int num_rows) { auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); - if (row >= num_rows) { return; } - if (!valid[row]) { return; } + if (row >= num_rows) return; + if (!valid[row]) return; int32_t val = values[row]; int left = 0; @@ -1346,7 +1346,7 @@ void launch_scan_all_fields(cudf::column_device_view const& d_in, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) { return; } + if (num_rows == 0) return; auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); scan_all_fields_kernel<<>>(d_in, field_descs, @@ -1376,7 +1376,7 @@ void launch_count_repeated_fields(cudf::column_device_view const& 
d_in, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) { return; } + if (num_rows == 0) return; auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); count_repeated_fields_kernel<<>>( d_in, @@ -1405,7 +1405,7 @@ void launch_scan_all_repeated_occurrences(cudf::column_device_view const& d_in, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) { return; } + if (num_rows == 0) return; auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); scan_all_repeated_occurrences_kernel<<>>( d_in, scan_descs, num_scan_fields, error_flag, fn_to_desc_idx, fn_to_desc_size); @@ -1418,7 +1418,7 @@ void launch_extract_strided_locations(field_location const* nested_locations, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) { return; } + if (num_rows == 0) return; auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); extract_strided_locations_kernel<<>>( nested_locations, field_idx, num_fields, parent_locs, num_rows); @@ -1436,7 +1436,7 @@ void launch_scan_nested_message_fields(uint8_t const* message_data, int* error_flag, rmm::cuda_stream_view stream) { - if (num_parent_rows == 0) { return; } + if (num_parent_rows == 0) return; auto const blocks = static_cast((num_parent_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); scan_nested_message_fields_kernel<<>>( @@ -1465,7 +1465,7 @@ void launch_scan_repeated_message_children(uint8_t const* message_data, int child_lookup_size, rmm::cuda_stream_view stream) { - if (num_occurrences == 0) { return; } + if (num_occurrences == 0) return; auto const blocks = static_cast((num_occurrences + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); scan_repeated_message_children_kernel<<>>( @@ -1496,7 +1496,7 @@ void launch_count_repeated_in_nested(uint8_t const* message_data, int* error_flag, rmm::cuda_stream_view stream) { - if (num_rows == 0) { return; } + if (num_rows == 0) return; auto const 
blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); count_repeated_in_nested_kernel<<>>( message_data, @@ -1526,7 +1526,7 @@ void launch_scan_repeated_in_nested(uint8_t const* message_data, int* error_flag, rmm::cuda_stream_view stream) { - if (num_rows == 0) { return; } + if (num_rows == 0) return; auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); scan_repeated_in_nested_kernel<<>>( message_data, @@ -1553,7 +1553,7 @@ void launch_compute_nested_struct_locations(field_location const* child_locs, int* error_flag, rmm::cuda_stream_view stream) { - if (total_count == 0) { return; } + if (total_count == 0) return; auto const blocks = static_cast((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); compute_nested_struct_locations_kernel<<>>( child_locs, @@ -1576,7 +1576,7 @@ void launch_compute_grandchild_parent_locations(field_location const* parent_loc int* error_flag, rmm::cuda_stream_view stream) { - if (num_rows == 0) { return; } + if (num_rows == 0) return; auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); compute_grandchild_parent_locations_kernel<<>>( parent_locs, child_locs, child_idx, num_child_fields, gc_parent_abs, num_rows, error_flag); @@ -1591,7 +1591,7 @@ void launch_compute_virtual_parents_for_nested_repeated(repeated_occurrence cons int* error_flag, rmm::cuda_stream_view stream) { - if (total_count == 0) { return; } + if (total_count == 0) return; auto const blocks = static_cast((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); compute_virtual_parents_for_nested_repeated_kernel<<((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); compute_msg_locations_from_occurrences_kernel<<>>( occurrences, list_offsets, base_offset, msg_locs, msg_row_offsets, total_count, error_flag); @@ -1628,7 +1628,7 @@ void launch_validate_enum_values(int32_t const* values, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) { 
return; } + if (num_rows == 0) return; auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); validate_enum_values_kernel<<>>( values, valid, row_has_invalid_enum, valid_enum_values, num_valid_values, num_rows); @@ -1643,7 +1643,7 @@ void launch_compute_enum_string_lengths(int32_t const* values, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) { return; } + if (num_rows == 0) return; auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); compute_enum_string_lengths_kernel<<>>( values, valid, valid_enum_values, enum_name_offsets, num_valid_values, lengths, num_rows); @@ -1660,7 +1660,7 @@ void launch_copy_enum_string_chars(int32_t const* values, int num_rows, rmm::cuda_stream_view stream) { - if (num_rows == 0) { return; } + if (num_rows == 0) return; auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); copy_enum_string_chars_kernel<<>>(values, valid, @@ -1720,7 +1720,7 @@ void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_ bool propagate_to_rows, rmm::cuda_stream_view stream) { - if (num_items == 0 || row_invalid.size() == 0 || !propagate_to_rows) { return; } + if (num_items == 0 || row_invalid.size() == 0 || !propagate_to_rows) return; if (top_row_indices == nullptr) { CUDF_EXPECTS(static_cast(num_items) <= row_invalid.size(), @@ -1742,7 +1742,7 @@ void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_ [item_invalid = item_invalid.data(), top_row_indices, row_invalid = row_invalid.data()] __device__(int idx) { - if (item_invalid[idx]) { row_invalid[top_row_indices[idx]] = true; } + if (item_invalid[idx]) row_invalid[top_row_indices[idx]] = true; }); } @@ -1755,7 +1755,7 @@ void validate_enum_and_propagate_rows(rmm::device_uvector const& values bool propagate_to_rows, rmm::cuda_stream_view stream) { - if (num_items == 0 || valid_enums.empty()) { return; } + if (num_items == 0 || valid_enums.empty()) 
return; auto const blocks = static_cast((num_items + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); auto d_valid_enums = cudf::detail::make_device_uvector_async( From c1f13c235e49a92b8f21a0d280df738b82b7c1de Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 7 Apr 2026 11:30:33 +0800 Subject: [PATCH 105/107] Replace copy_varlen_data_kernel with cub::DeviceMemcpy::Batched copy_varlen_data_kernel was removed in part0 review in favor of cub::DeviceMemcpy::Batched. Update the two remaining call sites in protobuf_builders.cu (repeated msg child varlen and repeated string column builders). Co-Authored-By: Claude Opus 4.6 (1M context) --- .../cpp/src/protobuf/protobuf_builders.cu | 90 +++++++++++++++---- 1 file changed, 75 insertions(+), 15 deletions(-) diff --git a/src/main/cpp/src/protobuf/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu index 1734a41704..c7ba8f9c4a 100644 --- a/src/main/cpp/src/protobuf/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -80,13 +80,43 @@ std::unique_ptr build_repeated_msg_child_varlen_column( d_child_locs.data(), child_idx, num_child_fields}; - copy_varlen_data_kernel - <<>>(message_data, - loc_provider, - total_count, - offsets_col->view().data(), - d_data.data(), - d_error.data()); + auto const* offsets_data = offsets_col->view().data(); + auto* chars_ptr = d_data.data(); + + auto src_iter = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type( + [message_data, loc_provider] __device__(int idx) -> void const* { + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + if (loc.offset < 0) return nullptr; + return static_cast(message_data + data_offset); + })); + auto dst_iter = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type([chars_ptr, offsets_data] __device__(int idx) -> void* { + return static_cast(chars_ptr + offsets_data[idx]); + })); + auto size_iter = cudf::detail::make_counting_transform_iterator( + 
0, + cuda::proclaim_return_type([loc_provider] __device__(int idx) -> size_t { + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + if (loc.offset < 0) return 0; + return static_cast(loc.length); + })); + + size_t temp_storage_bytes = 0; + cub::DeviceMemcpy::Batched( + nullptr, temp_storage_bytes, src_iter, dst_iter, size_iter, total_count, stream.value()); + rmm::device_buffer temp_storage(temp_storage_bytes, stream, mr); + cub::DeviceMemcpy::Batched(temp_storage.data(), + temp_storage_bytes, + src_iter, + dst_iter, + size_iter, + total_count, + stream.value()); } auto [mask, null_count] = make_null_mask_from_valid(d_valid, stream, mr); @@ -550,14 +580,44 @@ std::unique_ptr build_repeated_string_column( rmm::device_uvector chars(total_chars, stream, mr); if (total_chars > 0) { - repeated_location_provider loc_provider{list_offsets, base_offset, d_occurrences.data()}; - copy_varlen_data_kernel - <<>>(message_data, - loc_provider, - total_count, - str_offsets_col->view().data(), - chars.data(), - d_error.data()); + repeated_location_provider copy_provider{list_offsets, base_offset, d_occurrences.data()}; + auto const* offsets_data = str_offsets_col->view().data(); + auto* chars_ptr = chars.data(); + + auto src_iter = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type( + [message_data, copy_provider] __device__(int idx) -> void const* { + int32_t data_offset = 0; + auto loc = copy_provider.get(idx, data_offset); + if (loc.offset < 0) return nullptr; + return static_cast(message_data + data_offset); + })); + auto dst_iter = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type([chars_ptr, offsets_data] __device__(int idx) -> void* { + return static_cast(chars_ptr + offsets_data[idx]); + })); + auto size_iter = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type([copy_provider] __device__(int idx) -> size_t { + int32_t data_offset = 0; + auto loc = 
copy_provider.get(idx, data_offset); + if (loc.offset < 0) return 0; + return static_cast(loc.length); + })); + + size_t temp_storage_bytes = 0; + cub::DeviceMemcpy::Batched( + nullptr, temp_storage_bytes, src_iter, dst_iter, size_iter, total_count, stream.value()); + rmm::device_buffer temp_storage(temp_storage_bytes, stream, mr); + cub::DeviceMemcpy::Batched(temp_storage.data(), + temp_storage_bytes, + src_iter, + dst_iter, + size_iter, + total_count, + stream.value()); } std::unique_ptr child_col; From 33e30f7789a10a8e6fa83b359c7d0c91c1625e39 Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 7 Apr 2026 11:33:34 +0800 Subject: [PATCH 106/107] Remove unused threads/blocks variables These were only needed by the removed copy_varlen_data_kernel. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../cpp/src/protobuf/protobuf_builders.cu | 31 ++++++------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/src/main/cpp/src/protobuf/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu index c7ba8f9c4a..e4b149db2f 100644 --- a/src/main/cpp/src/protobuf/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -35,14 +35,10 @@ std::unique_ptr build_repeated_msg_child_varlen_column( rmm::device_async_resource_ref mr) { if (total_count == 0) { - if (as_bytes) - return make_empty_column_safe(cudf::data_type{cudf::type_id::LIST}, stream, mr); + if (as_bytes) return make_empty_column_safe(cudf::data_type{cudf::type_id::LIST}, stream, mr); return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); } - auto const threads = THREADS_PER_BLOCK; - auto const blocks = static_cast((total_count + threads - 1u) / threads); - rmm::device_uvector d_lengths(total_count, stream, mr); thrust::transform( rmm::exec_policy_nosync(stream), @@ -93,13 +89,11 @@ std::unique_ptr build_repeated_msg_child_varlen_column( return static_cast(message_data + data_offset); })); auto dst_iter = 
cudf::detail::make_counting_transform_iterator( - 0, - cuda::proclaim_return_type([chars_ptr, offsets_data] __device__(int idx) -> void* { + 0, cuda::proclaim_return_type([chars_ptr, offsets_data] __device__(int idx) -> void* { return static_cast(chars_ptr + offsets_data[idx]); })); auto size_iter = cudf::detail::make_counting_transform_iterator( - 0, - cuda::proclaim_return_type([loc_provider] __device__(int idx) -> size_t { + 0, cuda::proclaim_return_type([loc_provider] __device__(int idx) -> size_t { int32_t data_offset = 0; auto loc = loc_provider.get(idx, data_offset); if (loc.offset < 0) return 0; @@ -136,7 +130,6 @@ std::unique_ptr build_repeated_msg_child_varlen_column( total_count, std::move(offsets_col), d_data.release(), null_count, std::move(mask)); } - std::unique_ptr make_null_column(cudf::data_type dtype, cudf::size_type num_rows, rmm::cuda_stream_view stream, @@ -206,7 +199,6 @@ std::unique_ptr make_empty_column_safe(cudf::data_type dtype, } } - std::unique_ptr make_null_list_column_with_child( std::unique_ptr child_col, cudf::size_type num_rows, @@ -255,7 +247,8 @@ enum_string_lookup_tables make_enum_string_lookup_tables( auto d_valid_enums = cudf::detail::make_device_uvector_async( valid_enums, stream, cudf::get_current_device_resource_ref()); - auto h_name_offsets = cudf::detail::make_pinned_vector_async(valid_enums.size() + 1, stream); + auto h_name_offsets = + cudf::detail::make_pinned_vector_async(valid_enums.size() + 1, stream); std::fill(h_name_offsets.begin(), h_name_offsets.end(), 0); int64_t total_name_chars = 0; for (size_t k = 0; k < enum_name_bytes.size(); ++k) { @@ -479,12 +472,8 @@ std::unique_ptr build_repeated_enum_string_column( d_occurrences.end(), d_top_row_indices.begin(), [] __device__(repeated_occurrence const& occ) { return occ.row_idx; }); - propagate_invalid_enum_flags_to_rows(d_elem_has_invalid_enum, - d_row_force_null, - total_count, - d_top_row_indices.data(), - true, - stream); + propagate_invalid_enum_flags_to_rows( + 
d_elem_has_invalid_enum, d_row_force_null, total_count, d_top_row_indices.data(), true, stream); auto child_col = build_enum_string_values_column(enum_ints, elem_valid, lookup, total_count, stream, mr); @@ -594,13 +583,11 @@ std::unique_ptr build_repeated_string_column( return static_cast(message_data + data_offset); })); auto dst_iter = cudf::detail::make_counting_transform_iterator( - 0, - cuda::proclaim_return_type([chars_ptr, offsets_data] __device__(int idx) -> void* { + 0, cuda::proclaim_return_type([chars_ptr, offsets_data] __device__(int idx) -> void* { return static_cast(chars_ptr + offsets_data[idx]); })); auto size_iter = cudf::detail::make_counting_transform_iterator( - 0, - cuda::proclaim_return_type([copy_provider] __device__(int idx) -> size_t { + 0, cuda::proclaim_return_type([copy_provider] __device__(int idx) -> size_t { int32_t data_offset = 0; auto loc = copy_provider.get(idx, data_offset); if (loc.offset < 0) return 0; From 7508a8092f3cfc6d4283ee07aa9708e05cfcae8b Mon Sep 17 00:00:00 2001 From: Haoyang Li Date: Tue, 7 Apr 2026 12:47:51 +0800 Subject: [PATCH 107/107] style Signed-off-by: Haoyang Li --- src/main/cpp/src/protobuf/protobuf.hpp | 1 - src/main/cpp/src/protobuf/protobuf_host_helpers.hpp | 1 - 2 files changed, 2 deletions(-) diff --git a/src/main/cpp/src/protobuf/protobuf.hpp b/src/main/cpp/src/protobuf/protobuf.hpp index 600e9b7d95..803f0c2d93 100644 --- a/src/main/cpp/src/protobuf/protobuf.hpp +++ b/src/main/cpp/src/protobuf/protobuf.hpp @@ -105,7 +105,6 @@ protobuf_field_meta_view make_field_meta_view(protobuf_decode_context const& con } // namespace detail - std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& binary_input, protobuf_decode_context const& context, rmm::cuda_stream_view stream, diff --git a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp index 1e1910602e..c228329db2 100644 --- a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp +++ 
b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp @@ -113,7 +113,6 @@ inline cudf::type_id get_output_type_id(FieldT const& field) } } - template std::unique_ptr make_empty_struct_column_with_schema( SchemaT const& schema,