diff --git a/src/main/cpp/benchmarks/CMakeLists.txt b/src/main/cpp/benchmarks/CMakeLists.txt index 7641552622..00b39be8e1 100644 --- a/src/main/cpp/benchmarks/CMakeLists.txt +++ b/src/main/cpp/benchmarks/CMakeLists.txt @@ -86,3 +86,6 @@ ConfigureBench(GET_JSON_OBJECT_BENCH ConfigureBench(PARSE_URI_BENCH parse_uri.cpp) + +ConfigureBench(PROTOBUF_DECODE_BENCH + protobuf_decode.cu) diff --git a/src/main/cpp/benchmarks/protobuf_decode.cu b/src/main/cpp/benchmarks/protobuf_decode.cu new file mode 100644 index 0000000000..4551237e89 --- /dev/null +++ b/src/main/cpp/benchmarks/protobuf_decode.cu @@ -0,0 +1,1355 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace { + +// --------------------------------------------------------------------------- +// Protobuf wire-format encoding helpers (host side, for generating test data) +// --------------------------------------------------------------------------- + +void encode_varint(std::vector& buf, uint64_t value) +{ + while (value > 0x7F) { + buf.push_back(static_cast((value & 0x7F) | 0x80)); + value >>= 7; + } + buf.push_back(static_cast(value)); +} + +void encode_tag(std::vector& buf, int field_number, int wire_type) +{ + encode_varint(buf, (static_cast(field_number) << 3) | static_cast(wire_type)); +} + +void encode_varint_field(std::vector& buf, int field_number, int64_t value) +{ + encode_tag(buf, + field_number, + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT)); + encode_varint(buf, static_cast(value)); +} + +void encode_fixed32_field(std::vector& buf, int field_number, float value) +{ + encode_tag(buf, + field_number, + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT)); + uint32_t bits; + std::memcpy(&bits, &value, sizeof(bits)); + for (int i = 0; i < 4; i++) { + buf.push_back(static_cast(bits & 0xFF)); + bits >>= 8; + } +} + +void encode_fixed64_field(std::vector& buf, int field_number, double value) +{ + encode_tag(buf, + field_number, + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT)); + uint64_t bits; + std::memcpy(&bits, &value, sizeof(bits)); + for (int i = 0; i < 8; i++) { + buf.push_back(static_cast(bits & 0xFF)); + bits >>= 8; + } +} + +void encode_len_field(std::vector& buf, int field_number, void const* data, size_t len) +{ + encode_tag( + buf, field_number, spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)); + encode_varint(buf, len); + auto const* p = static_cast(data); + 
buf.insert(buf.end(), p, p + len); +} + +void encode_string_field(std::vector& buf, int field_number, std::string const& s) +{ + encode_len_field(buf, field_number, s.data(), s.size()); +} + +// Encode a nested message: write its content into a temporary buffer, then emit as LEN. +template +void encode_nested_message(std::vector& buf, int field_number, Fn&& content_fn) +{ + std::vector inner; + content_fn(inner); + encode_len_field(buf, field_number, inner.data(), inner.size()); +} + +// Encode a packed repeated int32 field. +void encode_packed_repeated_int32(std::vector& buf, + int field_number, + std::vector const& values) +{ + std::vector packed; + for (auto v : values) { + encode_varint(packed, static_cast(static_cast(v))); + } + encode_len_field(buf, field_number, packed.data(), packed.size()); +} + +// --------------------------------------------------------------------------- +// Build a cuDF LIST column from host message buffers +// --------------------------------------------------------------------------- + +std::unique_ptr make_binary_column(std::vector> const& messages) +{ + auto stream = cudf::get_default_stream(); + auto mr = rmm::mr::get_current_device_resource(); + + std::vector h_offsets(messages.size() + 1); + h_offsets[0] = 0; + for (size_t i = 0; i < messages.size(); i++) { + h_offsets[i + 1] = h_offsets[i] + static_cast(messages[i].size()); + } + int32_t total_bytes = h_offsets.back(); + + std::vector h_data; + h_data.reserve(total_bytes); + for (auto const& m : messages) { + h_data.insert(h_data.end(), m.begin(), m.end()); + } + + rmm::device_buffer d_data(h_data.data(), h_data.size(), stream, mr); + rmm::device_buffer d_offsets(h_offsets.data(), h_offsets.size() * sizeof(int32_t), stream, mr); + stream.synchronize(); + + auto child_col = std::make_unique( + cudf::data_type{cudf::type_id::UINT8}, total_bytes, std::move(d_data), rmm::device_buffer{}, 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + 
static_cast(h_offsets.size()), + std::move(d_offsets), + rmm::device_buffer{}, + 0); + + return cudf::make_lists_column(static_cast(messages.size()), + std::move(offsets_col), + std::move(child_col), + 0, + rmm::device_buffer{}); +} + +// --------------------------------------------------------------------------- +// Schema + message generators for different benchmark scenarios +// --------------------------------------------------------------------------- + +using nfd = spark_rapids_jni::nested_field_descriptor; +using pb_field_location = spark_rapids_jni::protobuf_detail::field_location; +using pb_repeated_occurrence = spark_rapids_jni::protobuf_detail::repeated_occurrence; + +inline int32_t checked_size_to_i32(size_t value, char const* what) +{ + if (value > static_cast(std::numeric_limits::max())) { + throw std::overflow_error(std::string("benchmark protobuf size exceeds int32_t for ") + what); + } + return static_cast(value); +} + +void encode_string_field_record(std::vector& buf, + int field_number, + std::string const& s, + std::vector& out_occurrences, + int32_t row_idx) +{ + encode_tag(buf, + field_number, + /*spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)=*/2); + encode_varint(buf, s.size()); + auto const data_offset = checked_size_to_i32(buf.size(), "string field data offset"); + buf.insert(buf.end(), s.begin(), s.end()); + out_occurrences.push_back( + {row_idx, data_offset, checked_size_to_i32(s.size(), "string field length")}); +} + +// Case 1: Flat scalars only — many top-level scalar fields. +// message FlatMessage { +// int32 f1 = 1; +// int64 f2 = 2; +// ... 
+// float f_k = k; (cycling through int32, int64, float, double, bool) +// string s_k+1 = k+1; (a few string fields) +// } +struct FlatScalarCase { + int num_int_fields; + int num_string_fields; + + spark_rapids_jni::protobuf_decode_context build_context() const + { + spark_rapids_jni::protobuf_decode_context ctx; + ctx.fail_on_errors = true; + + // type_id cycle for integer-like fields + cudf::type_id int_types[] = {cudf::type_id::INT32, + cudf::type_id::INT64, + cudf::type_id::FLOAT32, + cudf::type_id::FLOAT64, + cudf::type_id::BOOL8}; + int wt_for_type[] = { + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT)}; + + int fn = 1; + for (int i = 0; i < num_int_fields; i++, fn++) { + int ti = i % 5; + auto ty = int_types[ti]; + int wt = wt_for_type[ti]; + int enc = spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT); + if (ty == cudf::type_id::FLOAT32) + enc = spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED); + if (ty == cudf::type_id::FLOAT64) + enc = spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED); + ctx.schema.push_back({fn, -1, 0, wt, ty, enc, false, false, false}); + } + for (int i = 0; i < num_string_fields; i++, fn++) { + ctx.schema.push_back( + {fn, + -1, + 0, + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN), + cudf::type_id::STRING, + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT), + false, + false, + false}); + } + + size_t n = ctx.schema.size(); + for (auto const& f : ctx.schema) { + ctx.schema_output_types.emplace_back(f.output_type); + } + ctx.default_ints.resize(n, 0); + 
ctx.default_floats.resize(n, 0.0); + ctx.default_bools.resize(n, false); + ctx.default_strings.resize(n); + ctx.enum_valid_values.resize(n); + ctx.enum_names.resize(n); + return ctx; + } + + std::vector> generate_messages(int num_rows, std::mt19937& rng) const + { + std::vector> messages(num_rows); + std::uniform_int_distribution int_dist(0, 100000); + std::uniform_int_distribution str_len_dist(5, 50); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"; + + for (int r = 0; r < num_rows; r++) { + auto& buf = messages[r]; + int fn = 1; + for (int i = 0; i < num_int_fields; i++, fn++) { + int ti = i % 5; + switch (ti) { + case 0: encode_varint_field(buf, fn, int_dist(rng)); break; + case 1: encode_varint_field(buf, fn, int_dist(rng)); break; + case 2: encode_fixed32_field(buf, fn, static_cast(int_dist(rng))); break; + case 3: encode_fixed64_field(buf, fn, static_cast(int_dist(rng))); break; + case 4: encode_varint_field(buf, fn, rng() % 2); break; + } + } + for (int i = 0; i < num_string_fields; i++, fn++) { + int len = str_len_dist(rng); + std::string s(len, ' '); + for (int c = 0; c < len; c++) { + s[c] = alphabet[rng() % alphabet.size()]; + } + encode_string_field(buf, fn, s); + } + } + return messages; + } +}; + +// Case 2: Nested message — a top-level message with a nested struct child. +// message OuterMessage { +// int32 id = 1; +// string name = 2; +// InnerMessage inner = 3; +// } +// message InnerMessage { +// int32 x = 1; +// int64 y = 2; +// string data = 3; +// ... 
(num_inner_fields fields) +// } +struct NestedMessageCase { + int num_inner_fields; // scalar fields inside InnerMessage + + spark_rapids_jni::protobuf_decode_context build_context() const + { + spark_rapids_jni::protobuf_decode_context ctx; + ctx.fail_on_errors = true; + + // idx 0: id (int32, top-level) + ctx.schema.push_back({1, -1, 0, 0, cudf::type_id::INT32, 0, false, false, false}); + // idx 1: name (string, top-level) + ctx.schema.push_back({2, -1, 0, 2, cudf::type_id::STRING, 0, false, false, false}); + // idx 2: inner (STRUCT, top-level) + ctx.schema.push_back({3, -1, 0, 2, cudf::type_id::STRUCT, 0, false, false, false}); + + // Inner message children (parent_idx=2, depth=1) + cudf::type_id inner_types[] = { + cudf::type_id::INT32, cudf::type_id::INT64, cudf::type_id::STRING}; + int inner_wt[] = {0, 0, 2}; + + for (int i = 0; i < num_inner_fields; i++) { + int ti = i % 3; + ctx.schema.push_back({i + 1, 2, 1, inner_wt[ti], inner_types[ti], 0, false, false, false}); + } + + size_t n = ctx.schema.size(); + for (auto const& f : ctx.schema) { + ctx.schema_output_types.emplace_back(f.output_type); + } + ctx.default_ints.resize(n, 0); + ctx.default_floats.resize(n, 0.0); + ctx.default_bools.resize(n, false); + ctx.default_strings.resize(n); + ctx.enum_valid_values.resize(n); + ctx.enum_names.resize(n); + return ctx; + } + + std::vector> generate_messages(int num_rows, std::mt19937& rng) const + { + std::vector> messages(num_rows); + std::uniform_int_distribution int_dist(0, 100000); + std::uniform_int_distribution str_len_dist(5, 30); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + + auto random_string = [&](int len) { + std::string s(len, ' '); + for (int c = 0; c < len; c++) + s[c] = alphabet[rng() % alphabet.size()]; + return s; + }; + + for (int r = 0; r < num_rows; r++) { + auto& buf = messages[r]; + encode_varint_field(buf, 1, int_dist(rng)); + encode_string_field(buf, 2, random_string(str_len_dist(rng))); + + encode_nested_message(buf, 3, 
[&](std::vector& inner) { + for (int i = 0; i < num_inner_fields; i++) { + int ti = i % 3; + switch (ti) { + case 0: encode_varint_field(inner, i + 1, int_dist(rng)); break; + case 1: encode_varint_field(inner, i + 1, int_dist(rng)); break; + case 2: encode_string_field(inner, i + 1, random_string(str_len_dist(rng))); break; + } + } + }); + } + return messages; + } +}; + +// Case 3: Repeated fields — top-level repeated scalars and a repeated nested message. +// message RepeatedMessage { +// int32 id = 1; +// repeated int32 tags = 2; +// repeated string labels = 3; +// repeated Item items = 4; +// } +// message Item { +// int32 item_id = 1; +// string item_name = 2; +// int64 value = 3; +// } +struct RepeatedFieldCase { + int avg_tags_per_row; + int avg_labels_per_row; + int avg_items_per_row; + + spark_rapids_jni::protobuf_decode_context build_context() const + { + spark_rapids_jni::protobuf_decode_context ctx; + ctx.fail_on_errors = true; + + // idx 0: id (int32, scalar) + ctx.schema.push_back({1, -1, 0, 0, cudf::type_id::INT32, 0, false, false, false}); + // idx 1: tags (repeated int32, packed) + ctx.schema.push_back({2, -1, 0, 0, cudf::type_id::INT32, 0, true, false, false}); + // idx 2: labels (repeated string) + ctx.schema.push_back({3, -1, 0, 2, cudf::type_id::STRING, 0, true, false, false}); + // idx 3: items (repeated STRUCT) + ctx.schema.push_back({4, -1, 0, 2, cudf::type_id::STRUCT, 0, true, false, false}); + // idx 4: Item.item_id (int32, child of idx 3) + ctx.schema.push_back({1, 3, 1, 0, cudf::type_id::INT32, 0, false, false, false}); + // idx 5: Item.item_name (string, child of idx 3) + ctx.schema.push_back({2, 3, 1, 2, cudf::type_id::STRING, 0, false, false, false}); + // idx 6: Item.value (int64, child of idx 3) + ctx.schema.push_back({3, 3, 1, 0, cudf::type_id::INT64, 0, false, false, false}); + + size_t n = ctx.schema.size(); + for (auto const& f : ctx.schema) { + ctx.schema_output_types.emplace_back(f.output_type); + } + 
ctx.default_ints.resize(n, 0); + ctx.default_floats.resize(n, 0.0); + ctx.default_bools.resize(n, false); + ctx.default_strings.resize(n); + ctx.enum_valid_values.resize(n); + ctx.enum_names.resize(n); + return ctx; + } + + std::vector> generate_messages(int num_rows, std::mt19937& rng) const + { + std::vector> messages(num_rows); + std::uniform_int_distribution int_dist(0, 100000); + std::uniform_int_distribution str_len_dist(3, 20); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + + auto random_string = [&](int len) { + std::string s(len, ' '); + for (int c = 0; c < len; c++) + s[c] = alphabet[rng() % alphabet.size()]; + return s; + }; + + // Vary count per row around the average (±50%) + auto vary = [&](int avg) -> int { + int lo = std::max(0, avg / 2); + int hi = avg + avg / 2; + return std::uniform_int_distribution(lo, std::max(lo, hi))(rng); + }; + + for (int r = 0; r < num_rows; r++) { + auto& buf = messages[r]; + + // id + encode_varint_field(buf, 1, int_dist(rng)); + + // tags (packed repeated int32) + { + int n = vary(avg_tags_per_row); + std::vector tags(n); + for (auto& t : tags) + t = int_dist(rng); + if (n > 0) encode_packed_repeated_int32(buf, 2, tags); + } + + // labels (unpacked repeated string) + { + int n = vary(avg_labels_per_row); + for (int i = 0; i < n; i++) { + encode_string_field(buf, 3, random_string(str_len_dist(rng))); + } + } + + // items (repeated nested message) + { + int n = vary(avg_items_per_row); + for (int i = 0; i < n; i++) { + encode_nested_message(buf, 4, [&](std::vector& inner) { + encode_varint_field(inner, 1, int_dist(rng)); + encode_string_field(inner, 2, random_string(str_len_dist(rng))); + encode_varint_field(inner, 3, int_dist(rng)); + }); + } + } + } + return messages; + } +}; + +// Case 4: Wide repeated message — stress-tests repeated struct child scanning. 
+// message WideRepeatedMessage { +// int32 id = 1; +// repeated Item items = 2; +// } +// message Item { +// int32 / int64 / float / double / bool / string child fields ... +// ... (num_child_fields fields) +// } +// +// This case is intentionally generic and contains no customer schema details. +// It is designed to exercise `scan_repeated_message_children_kernel` with a +// wide repeated STRUCT payload similar in shape to real-world schema-projection +// workloads. +struct WideRepeatedMessageCase { + int num_child_fields; + int avg_items_per_row; + + spark_rapids_jni::protobuf_decode_context build_context() const + { + spark_rapids_jni::protobuf_decode_context ctx; + ctx.fail_on_errors = true; + + // idx 0: id (scalar) + ctx.schema.push_back({1, -1, 0, 0, cudf::type_id::INT32, 0, false, false, false}); + // idx 1: items (repeated STRUCT) + ctx.schema.push_back({2, -1, 0, 2, cudf::type_id::STRUCT, 0, true, false, false}); + + cudf::type_id child_types[] = {cudf::type_id::INT32, + cudf::type_id::INT64, + cudf::type_id::FLOAT32, + cudf::type_id::FLOAT64, + cudf::type_id::BOOL8, + cudf::type_id::STRING}; + int child_wt[] = {spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I32BIT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::I64BIT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::VARINT), + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN)}; + int child_enc[] = {spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT), + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT), + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED), + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::FIXED), + 
spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT), + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT)}; + + // Keep strings sparse so the case remains dominated by wide child scanning + // rather than varlen copy traffic. + for (int i = 0; i < num_child_fields; i++) { + int ti = (i % 10 == 9) ? 5 : (i % 5); + ctx.schema.push_back( + {i + 1, 1, 1, child_wt[ti], child_types[ti], child_enc[ti], false, false, false}); + } + + size_t n = ctx.schema.size(); + for (auto const& f : ctx.schema) { + ctx.schema_output_types.emplace_back(f.output_type); + } + ctx.default_ints.resize(n, 0); + ctx.default_floats.resize(n, 0.0); + ctx.default_bools.resize(n, false); + ctx.default_strings.resize(n); + ctx.enum_valid_values.resize(n); + ctx.enum_names.resize(n); + return ctx; + } + + std::vector> generate_messages(int num_rows, std::mt19937& rng) const + { + std::vector> messages(num_rows); + std::uniform_int_distribution int_dist(0, 100000); + std::uniform_int_distribution str_len_dist(6, 18); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + + auto random_string = [&](int len) { + std::string s(len, ' '); + for (int c = 0; c < len; c++) + s[c] = alphabet[rng() % alphabet.size()]; + return s; + }; + + auto vary = [&](int avg) -> int { + int lo = std::max(0, avg / 2); + int hi = avg + avg / 2; + return std::uniform_int_distribution(lo, std::max(lo, hi))(rng); + }; + + for (int r = 0; r < num_rows; r++) { + auto& buf = messages[r]; + encode_varint_field(buf, 1, int_dist(rng)); + + int n = vary(avg_items_per_row); + for (int item_idx = 0; item_idx < n; item_idx++) { + encode_nested_message(buf, 2, [&](std::vector& inner) { + for (int i = 0; i < num_child_fields; i++) { + int ti = (i % 10 == 9) ? 
5 : (i % 5); + int fn = i + 1; + switch (ti) { + case 0: encode_varint_field(inner, fn, int_dist(rng)); break; + case 1: encode_varint_field(inner, fn, int_dist(rng)); break; + case 2: encode_fixed32_field(inner, fn, static_cast(int_dist(rng))); break; + case 3: encode_fixed64_field(inner, fn, static_cast(int_dist(rng))); break; + case 4: encode_varint_field(inner, fn, rng() % 2); break; + case 5: encode_string_field(inner, fn, random_string(str_len_dist(rng))); break; + } + } + }); + } + } + return messages; + } +}; + +// Case 5: Repeated child lists — stress-tests repeated fields inside a repeated +// struct child, which exercises build_repeated_child_list_column(). +// message OuterMessage { +// int32 id = 1; +// repeated Item items = 2; +// } +// message Item { +// repeated int32 r_int_* = 1..N +// repeated string r_str_* = ... +// } +// +// This case is intentionally generic and contains no customer schema details. +struct RepeatedChildListCase { + int num_repeated_children; + int avg_items_per_row; + int avg_child_elems; + std::string child_mix; + + bool child_is_string(int child_idx) const + { + if (child_mix == "string_only") return true; + if (child_mix == "int_only") return false; + return (child_idx % 4 == 3); + } + + spark_rapids_jni::protobuf_decode_context build_context() const + { + spark_rapids_jni::protobuf_decode_context ctx; + ctx.fail_on_errors = true; + + // idx 0: id (scalar) + ctx.schema.push_back({1, -1, 0, 0, cudf::type_id::INT32, 0, false, false, false}); + // idx 1: items (repeated STRUCT) + ctx.schema.push_back({2, -1, 0, 2, cudf::type_id::STRUCT, 0, true, false, false}); + + for (int i = 0; i < num_repeated_children; i++) { + bool as_string = child_is_string(i); + ctx.schema.push_back({i + 1, + 1, + 1, + as_string ? 2 : 0, + as_string ? 
cudf::type_id::STRING : cudf::type_id::INT32, + 0, + true, + false, + false}); + } + + size_t n = ctx.schema.size(); + for (auto const& f : ctx.schema) { + ctx.schema_output_types.emplace_back(f.output_type); + } + ctx.default_ints.resize(n, 0); + ctx.default_floats.resize(n, 0.0); + ctx.default_bools.resize(n, false); + ctx.default_strings.resize(n); + ctx.enum_valid_values.resize(n); + ctx.enum_names.resize(n); + return ctx; + } + + std::vector> generate_messages(int num_rows, std::mt19937& rng) const + { + std::vector> messages(num_rows); + std::uniform_int_distribution int_dist(0, 100000); + std::uniform_int_distribution str_len_dist(4, 16); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + + auto random_string = [&](int len) { + std::string s(len, ' '); + for (int c = 0; c < len; c++) + s[c] = alphabet[rng() % alphabet.size()]; + return s; + }; + + auto vary = [&](int avg) -> int { + int lo = std::max(0, avg / 2); + int hi = avg + avg / 2; + return std::uniform_int_distribution(lo, std::max(lo, hi))(rng); + }; + + for (int r = 0; r < num_rows; r++) { + auto& buf = messages[r]; + encode_varint_field(buf, 1, int_dist(rng)); + + int num_items = vary(avg_items_per_row); + for (int item_idx = 0; item_idx < num_items; item_idx++) { + encode_nested_message(buf, 2, [&](std::vector& inner) { + for (int child_idx = 0; child_idx < num_repeated_children; child_idx++) { + int fn = child_idx + 1; + bool is_str = child_is_string(child_idx); + int num_elems = vary(avg_child_elems); + if (is_str) { + for (int j = 0; j < num_elems; j++) { + encode_string_field(inner, fn, random_string(str_len_dist(rng))); + } + } else { + if (num_elems > 0) { + std::vector vals(num_elems); + for (auto& v : vals) + v = int_dist(rng); + encode_packed_repeated_int32(inner, fn, vals); + } + } + } + }); + } + } + return messages; + } +}; + +struct RepeatedChildStringBenchData { + std::vector> messages; + std::vector parent_locs; + std::vector> counts_by_child; + std::vector> 
occurrences_by_child; +}; + +struct RepeatedChildStringOnlyCase { + int num_repeated_children; + int avg_child_elems; + + RepeatedChildStringBenchData generate_data(int num_rows, std::mt19937& rng) const + { + RepeatedChildStringBenchData out; + out.messages.resize(num_rows); + out.parent_locs.resize(num_rows); + out.counts_by_child.resize(num_repeated_children); + out.occurrences_by_child.resize(num_repeated_children); + + std::uniform_int_distribution str_len_dist(4, 16); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + + auto random_string = [&](int len) { + std::string s(len, ' '); + for (int c = 0; c < len; c++) + s[c] = alphabet[rng() % alphabet.size()]; + return s; + }; + + auto vary = [&](int avg) -> int { + int lo = std::max(0, avg / 2); + int hi = avg + avg / 2; + return std::uniform_int_distribution(lo, std::max(lo, hi))(rng); + }; + + for (int row = 0; row < num_rows; row++) { + auto& buf = out.messages[row]; + for (int child_idx = 0; child_idx < num_repeated_children; child_idx++) { + int fn = child_idx + 1; + int num_elems = vary(avg_child_elems); + out.counts_by_child[child_idx].push_back(num_elems); + for (int j = 0; j < num_elems; j++) { + encode_string_field_record( + buf, fn, random_string(str_len_dist(rng)), out.occurrences_by_child[child_idx], row); + } + } + out.parent_locs[row] = {0, static_cast(buf.size())}; + } + return out; + } +}; + +// Case 6: Many repeated fields — stress-tests per-repeated-field sync overhead. +// message WideRepeatedMessage { +// int32 id = 1; +// repeated int32 r_int_1 = 2; +// repeated int32 r_int_2 = 3; +// ... +// repeated string r_str_1 = N; +// repeated string r_str_2 = N+1; +// ... 
+// } +struct ManyRepeatedFieldsCase { + int num_repeated_int; + int num_repeated_str; + + spark_rapids_jni::protobuf_decode_context build_context() const + { + spark_rapids_jni::protobuf_decode_context ctx; + ctx.fail_on_errors = true; + + int fn = 1; + // idx 0: id (scalar) + ctx.schema.push_back({fn++, -1, 0, 0, cudf::type_id::INT32, 0, false, false, false}); + + for (int i = 0; i < num_repeated_int; i++) { + ctx.schema.push_back({fn++, -1, 0, 0, cudf::type_id::INT32, 0, true, false, false}); + } + for (int i = 0; i < num_repeated_str; i++) { + ctx.schema.push_back({fn++, -1, 0, 2, cudf::type_id::STRING, 0, true, false, false}); + } + + size_t n = ctx.schema.size(); + for (auto const& f : ctx.schema) { + ctx.schema_output_types.emplace_back(f.output_type); + } + ctx.default_ints.resize(n, 0); + ctx.default_floats.resize(n, 0.0); + ctx.default_bools.resize(n, false); + ctx.default_strings.resize(n); + ctx.enum_valid_values.resize(n); + ctx.enum_names.resize(n); + return ctx; + } + + std::vector> generate_messages(int num_rows, + int avg_elems_per_field, + std::mt19937& rng) const + { + std::vector> messages(num_rows); + std::uniform_int_distribution int_dist(0, 100000); + std::uniform_int_distribution str_len_dist(3, 15); + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + + auto random_string = [&](int len) { + std::string s(len, ' '); + for (int c = 0; c < len; c++) + s[c] = alphabet[rng() % alphabet.size()]; + return s; + }; + auto vary = [&](int avg) -> int { + int lo = std::max(0, avg / 2); + int hi = avg + avg / 2; + return std::uniform_int_distribution(lo, std::max(lo, hi))(rng); + }; + + for (int r = 0; r < num_rows; r++) { + auto& buf = messages[r]; + int fn = 1; + + encode_varint_field(buf, fn++, int_dist(rng)); + + for (int i = 0; i < num_repeated_int; i++) { + int cur_fn = fn++; + int n = vary(avg_elems_per_field); + if (n > 0) { + std::vector vals(n); + for (auto& v : vals) + v = int_dist(rng); + encode_packed_repeated_int32(buf, cur_fn, vals); 
+ } + } + for (int i = 0; i < num_repeated_str; i++) { + int cur_fn = fn++; + int n = vary(avg_elems_per_field); + for (int j = 0; j < n; j++) { + encode_string_field(buf, cur_fn, random_string(str_len_dist(rng))); + } + } + } + return messages; + } +}; + +} // anonymous namespace + +// =========================================================================== +// Benchmark 1: Flat scalars — measures per-field extraction overhead +// =========================================================================== +static void BM_protobuf_flat_scalars(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const num_fields = static_cast(state.get_int64("num_fields")); + int const num_str = std::max(1, num_fields / 10); + int const num_int = num_fields - num_str; + + FlatScalarCase flat_case{num_int, num_str}; + auto ctx = flat_case.build_context(); + + std::mt19937 rng(42); + auto messages = flat_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); + + size_t total_bytes = 0; + for (auto const& m : messages) + total_bytes += m.size(); + + auto stream = cudf::get_default_stream(); + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + auto result = spark_rapids_jni::decode_protobuf_to_struct(binary_col->view(), ctx, stream); + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_flat_scalars) + .set_name("Protobuf Flat Scalars") + .add_int64_axis("num_rows", {10'000, 100'000, 500'000}) + .add_int64_axis("num_fields", {10, 50, 200}); + +// =========================================================================== +// Benchmark 2: Nested messages — measures nested struct build overhead +// =========================================================================== +static void BM_protobuf_nested(nvbench::state& state) +{ + auto const 
num_rows = static_cast(state.get_int64("num_rows")); + auto const inner_fields = static_cast(state.get_int64("inner_fields")); + + NestedMessageCase nested_case{inner_fields}; + auto ctx = nested_case.build_context(); + + std::mt19937 rng(42); + auto messages = nested_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); + + size_t total_bytes = 0; + for (auto const& m : messages) + total_bytes += m.size(); + + auto stream = cudf::get_default_stream(); + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + auto result = spark_rapids_jni::decode_protobuf_to_struct(binary_col->view(), ctx, stream); + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_nested) + .set_name("Protobuf Nested Message") + .add_int64_axis("num_rows", {10'000, 100'000, 500'000}) + .add_int64_axis("inner_fields", {5, 20, 100}); + +// =========================================================================== +// Benchmark 3: Repeated fields — measures repeated field pipeline overhead +// =========================================================================== +static void BM_protobuf_repeated(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const avg_items = static_cast(state.get_int64("avg_items")); + + RepeatedFieldCase rep_case{/*avg_tags=*/5, /*avg_labels=*/3, /*avg_items=*/avg_items}; + auto ctx = rep_case.build_context(); + + std::mt19937 rng(42); + auto messages = rep_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); + + size_t total_bytes = 0; + for (auto const& m : messages) + total_bytes += m.size(); + + auto stream = cudf::get_default_stream(); + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + auto result = 
spark_rapids_jni::decode_protobuf_to_struct(binary_col->view(), ctx, stream); + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_repeated) + .set_name("Protobuf Repeated Fields") + .add_int64_axis("num_rows", {10'000, 100'000}) + .add_int64_axis("avg_items", {1, 5, 20}); + +// =========================================================================== +// Benchmark 4: Wide repeated message — measures repeated struct child scan cost +// =========================================================================== +static void BM_protobuf_wide_repeated_message(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const num_child_fields = static_cast(state.get_int64("num_child_fields")); + auto const avg_items = static_cast(state.get_int64("avg_items")); + + WideRepeatedMessageCase wide_case{num_child_fields, avg_items}; + auto ctx = wide_case.build_context(); + + std::mt19937 rng(42); + auto messages = wide_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); + + size_t total_bytes = 0; + for (auto const& m : messages) + total_bytes += m.size(); + + auto stream = cudf::get_default_stream(); + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + auto result = spark_rapids_jni::decode_protobuf_to_struct(binary_col->view(), ctx, stream); + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_wide_repeated_message) + .set_name("Protobuf Wide Repeated Message") + .add_int64_axis("num_rows", {10'000, 20'000}) + .add_int64_axis("num_child_fields", {20, 100, 200}) + .add_int64_axis("avg_items", {1, 5, 10}); + +// =========================================================================== +// Benchmark 5: Repeated child lists — measures repeated-in-nested list 
overhead +// =========================================================================== +static void BM_protobuf_repeated_child_lists(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const num_repeated_children = static_cast(state.get_int64("num_repeated_children")); + auto const avg_items = static_cast(state.get_int64("avg_items")); + auto const avg_child_elems = static_cast(state.get_int64("avg_child_elems")); + auto const child_mix = state.get_string("child_mix"); + + RepeatedChildListCase list_case{ + num_repeated_children, avg_items, avg_child_elems, std::string(child_mix)}; + auto ctx = list_case.build_context(); + + std::mt19937 rng(42); + auto messages = list_case.generate_messages(num_rows, rng); + auto binary_col = make_binary_column(messages); + + size_t total_bytes = 0; + for (auto const& m : messages) + total_bytes += m.size(); + + auto stream = cudf::get_default_stream(); + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + auto result = spark_rapids_jni::decode_protobuf_to_struct(binary_col->view(), ctx, stream); + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_repeated_child_lists) + .set_name("Protobuf Repeated Child Lists") + .add_int64_axis("num_rows", {10'000, 20'000}) + .add_int64_axis("num_repeated_children", {1, 4, 8}) + .add_int64_axis("avg_items", {1, 5}) + .add_int64_axis("avg_child_elems", {1, 5}) + .add_string_axis("child_mix", {"int_only", "mixed", "string_only"}); + +// =========================================================================== +// Benchmark 6: Repeated child string count+scan only +// =========================================================================== +static void BM_protobuf_repeated_child_string_count_scan(nvbench::state& state) +{ + auto const num_rows = 
static_cast(state.get_int64("num_rows")); + auto const num_repeated_children = static_cast(state.get_int64("num_repeated_children")); + auto const avg_child_elems = static_cast(state.get_int64("avg_child_elems")); + + RepeatedChildStringOnlyCase list_case{num_repeated_children, avg_child_elems}; + std::mt19937 rng(42); + auto data = list_case.generate_data(num_rows, rng); + auto binary_col = make_binary_column(data.messages); + + cudf::lists_column_view in_list(binary_col->view()); + auto const* row_offsets = in_list.offsets().data(); + auto const* message_data = reinterpret_cast(in_list.child().data()); + auto const message_data_size = static_cast(in_list.child().size()); + + auto stream = cudf::get_default_stream(); + auto mr = cudf::get_current_device_resource_ref(); + + rmm::device_uvector d_parent_locs(num_rows, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_parent_locs.data(), + data.parent_locs.data(), + num_rows * sizeof(pb_field_location), + cudaMemcpyHostToDevice, + stream.value())); + + std::vector h_schema( + num_repeated_children); + for (int i = 0; i < num_repeated_children; i++) { + h_schema[i].field_number = i + 1; + h_schema[i].parent_idx = -1; + h_schema[i].depth = 0; + h_schema[i].wire_type = + spark_rapids_jni::wire_type_value(spark_rapids_jni::proto_wire_type::LEN); + h_schema[i].output_type_id = static_cast(cudf::type_id::STRING); + h_schema[i].encoding = + spark_rapids_jni::encoding_value(spark_rapids_jni::proto_encoding::DEFAULT); + h_schema[i].is_repeated = true; + h_schema[i].is_required = false; + h_schema[i].has_default_value = false; + } + rmm::device_uvector d_schema( + num_repeated_children, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_schema.data(), + h_schema.data(), + num_repeated_children * sizeof(h_schema[0]), + cudaMemcpyHostToDevice, + stream.value())); + + std::vector h_rep_indices(num_repeated_children); + for (int i = 0; i < num_repeated_children; i++) { + CUDF_EXPECTS(h_schema[i].is_repeated, + 
"count_repeated_in_nested_kernel benchmark expects repeated child fields"); + CUDF_EXPECTS(h_schema[i].depth == 0, + "count_repeated_in_nested_kernel benchmark expects pre-filtered child depth 0"); + h_rep_indices[i] = i; + } + rmm::device_uvector d_rep_indices(num_repeated_children, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_indices.data(), + h_rep_indices.data(), + num_repeated_children * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + + size_t total_bytes = 0; + for (auto const& m : data.messages) + total_bytes += m.size(); + + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + rmm::device_uvector d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + + rmm::device_uvector d_rep_info( + static_cast(num_rows) * num_repeated_children, stream, mr); + spark_rapids_jni::protobuf_detail:: + count_repeated_in_nested_kernel<<<(num_rows + 255) / 256, 256, 0, stream.value()>>>( + message_data, + message_data_size, + row_offsets, + 0, + d_parent_locs.data(), + num_rows, + d_schema.data(), + num_repeated_children, + d_rep_info.data(), + num_repeated_children, + d_rep_indices.data(), + d_error.data()); + + struct rep_work { + rmm::device_uvector counts; + rmm::device_uvector offsets; + int32_t total_count{0}; + std::unique_ptr> occs; + rep_work(int n, rmm::cuda_stream_view s, rmm::device_async_resource_ref m) + : counts(n, s, m), offsets(n + 1, s, m) + { + } + }; + + std::vector> work; + work.reserve(num_repeated_children); + for (int ri = 0; ri < num_repeated_children; ri++) { + auto& w = *work.emplace_back(std::make_unique(num_rows, stream, mr)); + thrust::transform(rmm::exec_policy_nosync(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + w.counts.data(), + spark_rapids_jni::protobuf_detail::extract_strided_count{ + d_rep_info.data(), ri, num_repeated_children}); + 
CUDF_CUDA_TRY(cudaMemsetAsync(w.offsets.data(), 0, sizeof(int32_t), stream.value())); + thrust::inclusive_scan( + rmm::exec_policy_nosync(stream), w.counts.begin(), w.counts.end(), w.offsets.data() + 1); + CUDF_CUDA_TRY(cudaMemcpyAsync(&w.total_count, + w.offsets.data() + num_rows, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + } + stream.synchronize(); + + for (int ri = 0; ri < num_repeated_children; ri++) { + auto& w = *work[ri]; + if (w.total_count > 0) { + w.occs = + std::make_unique>(w.total_count, stream, mr); + spark_rapids_jni::protobuf_detail:: + scan_repeated_in_nested_kernel<<<(num_rows + 255) / 256, 256, 0, stream.value()>>>( + message_data, + message_data_size, + row_offsets, + 0, + d_parent_locs.data(), + num_rows, + d_schema.data(), + w.offsets.data(), + d_rep_indices.data() + ri, + w.occs->data(), + d_error.data()); + } + } + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_repeated_child_string_count_scan) + .set_name("Protobuf Repeated Child String CountScan") + .add_int64_axis("num_rows", {10'000, 20'000}) + .add_int64_axis("num_repeated_children", {1, 4, 8}) + .add_int64_axis("avg_child_elems", {1, 5}); + +// =========================================================================== +// Benchmark 7: Repeated child string build-only +// =========================================================================== +static void BM_protobuf_repeated_child_string_build(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const num_repeated_children = static_cast(state.get_int64("num_repeated_children")); + auto const avg_child_elems = static_cast(state.get_int64("avg_child_elems")); + + RepeatedChildStringOnlyCase list_case{num_repeated_children, avg_child_elems}; + std::mt19937 rng(42); + auto data = list_case.generate_data(num_rows, rng); + auto binary_col = make_binary_column(data.messages); + + 
cudf::lists_column_view in_list(binary_col->view()); + auto const* row_offsets = in_list.offsets().data(); + auto const* message_data = reinterpret_cast(in_list.child().data()); + + auto stream = cudf::get_default_stream(); + auto mr = cudf::get_current_device_resource_ref(); + + rmm::device_uvector d_parent_locs(num_rows, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_parent_locs.data(), + data.parent_locs.data(), + num_rows * sizeof(pb_field_location), + cudaMemcpyHostToDevice, + stream.value())); + + struct precomputed_child { + rmm::device_uvector counts; + rmm::device_uvector occs; + int total_count; + precomputed_child(int nrows, + int total, + rmm::cuda_stream_view s, + rmm::device_async_resource_ref m) + : counts(nrows, s, m), occs(total, s, m), total_count(total) + { + } + }; + + std::vector> children; + children.reserve(num_repeated_children); + for (int i = 0; i < num_repeated_children; i++) { + int total = static_cast(data.occurrences_by_child[i].size()); + auto& c = + *children.emplace_back(std::make_unique(num_rows, total, stream, mr)); + CUDF_CUDA_TRY(cudaMemcpyAsync(c.counts.data(), + data.counts_by_child[i].data(), + num_rows * sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + if (total > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(c.occs.data(), + data.occurrences_by_child[i].data(), + total * sizeof(pb_repeated_occurrence), + cudaMemcpyHostToDevice, + stream.value())); + } + } + + size_t total_bytes = 0; + for (auto const& m : data.messages) + total_bytes += m.size(); + + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + rmm::device_uvector d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + + for (int i = 0; i < num_repeated_children; i++) { + auto& c = *children[i]; + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy_nosync(stream), 
c.counts.begin(), c.counts.end(), list_offs.begin(), 0); + CUDF_CUDA_TRY(cudaMemcpyAsync(list_offs.data() + num_rows, + &c.total_count, + sizeof(int32_t), + cudaMemcpyHostToDevice, + stream.value())); + + spark_rapids_jni::protobuf_detail::nested_repeated_location_provider nr_loc{ + row_offsets, 0, d_parent_locs.data(), c.occs.data()}; + auto valid_fn = [] __device__(cudf::size_type) { return true; }; + std::vector empty_default; + auto child_values = + spark_rapids_jni::protobuf_detail::extract_and_build_string_or_bytes_column(false, + message_data, + c.total_count, + nr_loc, + nr_loc, + valid_fn, + false, + empty_default, + d_error, + stream, + mr); + auto list_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + auto result = cudf::make_lists_column( + num_rows, std::move(list_offs_col), std::move(child_values), 0, rmm::device_buffer{}); + } + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_repeated_child_string_build) + .set_name("Protobuf Repeated Child String Build") + .add_int64_axis("num_rows", {10'000, 20'000}) + .add_int64_axis("num_repeated_children", {1, 4, 8}) + .add_int64_axis("avg_child_elems", {1, 5}); + +// =========================================================================== +// Benchmark 8: Many repeated fields — measures per-field sync overhead at scale +// =========================================================================== +static void BM_protobuf_many_repeated(nvbench::state& state) +{ + auto const num_rows = static_cast(state.get_int64("num_rows")); + auto const num_rep_fields = static_cast(state.get_int64("num_rep_fields")); + + int const num_rep_str = std::max(1, num_rep_fields / 5); + int const num_rep_int = num_rep_fields - num_rep_str; + + ManyRepeatedFieldsCase many_case{num_rep_int, num_rep_str}; + auto ctx = many_case.build_context(); + + std::mt19937 
rng(42); + auto messages = many_case.generate_messages(num_rows, /*avg_elems=*/3, rng); + auto binary_col = make_binary_column(messages); + + size_t total_bytes = 0; + for (auto const& m : messages) + total_bytes += m.size(); + + auto stream = cudf::get_default_stream(); + state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value())); + state.exec(nvbench::exec_tag::sync, [&](nvbench::launch&) { + auto result = spark_rapids_jni::decode_protobuf_to_struct(binary_col->view(), ctx, stream); + }); + + state.add_element_count(num_rows, "Rows"); + state.add_global_memory_reads(total_bytes); +} + +NVBENCH_BENCH(BM_protobuf_many_repeated) + .set_name("Protobuf Many Repeated Fields") + .add_int64_axis("num_rows", {10'000, 100'000}) + .add_int64_axis("num_rep_fields", {10, 30, 50}); diff --git a/src/main/cpp/src/protobuf/protobuf.cu b/src/main/cpp/src/protobuf/protobuf.cu index 1082843cad..ab34ceccf2 100644 --- a/src/main/cpp/src/protobuf/protobuf.cu +++ b/src/main/cpp/src/protobuf/protobuf.cu @@ -19,6 +19,9 @@ #include +#include + +#include #include #include @@ -28,6 +31,112 @@ namespace detail { namespace { +void propagate_nulls_to_descendants(cudf::column& col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr); + +void apply_parent_mask_to_row_aligned_column(cudf::column& col, + cudf::bitmask_type const* parent_mask_ptr, + cudf::size_type parent_null_count, + cudf::size_type num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (parent_null_count == 0) { return; } + auto child_view = col.mutable_view(); + CUDF_EXPECTS(child_view.size() == num_rows, + "struct child size must match parent row count for null propagation"); + + if (child_view.nullable()) { + auto const child_mask_words = + cudf::num_bitmask_words(static_cast(child_view.size() + child_view.offset())); + std::array masks{child_view.null_mask(), parent_mask_ptr}; + std::array begin_bits{child_view.offset(), 0}; + auto const valid_count = 
cudf::detail::inplace_bitmask_and( + cudf::device_span(child_view.null_mask(), child_mask_words), + cudf::host_span(masks.data(), masks.size()), + cudf::host_span(begin_bits.data(), begin_bits.size()), + child_view.size(), + stream); + col.set_null_count(child_view.size() - valid_count); + } else { + CUDF_EXPECTS(child_view.offset() == 0, + "non-nullable child with nonzero offset not supported for null propagation"); + auto child_mask = cudf::detail::copy_bitmask(parent_mask_ptr, 0, num_rows, stream, mr); + col.set_null_mask(std::move(child_mask), parent_null_count); + } +} + +void propagate_list_nulls_to_descendants(cudf::column& list_col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (list_col.type().id() != cudf::type_id::LIST || list_col.null_count() == 0) { return; } + + cudf::lists_column_view const list_view(list_col.view()); + auto const* list_mask_ptr = list_view.null_mask(); + auto const num_rows = list_view.size(); + auto& child = list_col.child(cudf::lists_column_view::child_column_index); + auto const child_size = child.size(); + if (child_size == 0) { return; } + + CUDF_EXPECTS(list_view.offset() == 0, + "decoder list null propagation expects unsliced list columns"); + auto const* offsets_begin = list_view.offsets_begin(); + auto const* offsets_end = list_view.offsets_end(); + // LIST children are not row-aligned with their parent. Expand the list-row null mask across + // every covered child element so direct access to the backing child column also observes nulls. 
+ auto [element_mask, element_null_count] = cudf::detail::valid_if( + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(child_size), + [list_mask_ptr, offsets_begin, offsets_end] __device__(cudf::size_type idx) { + auto const it = thrust::upper_bound(thrust::seq, offsets_begin, offsets_end, idx); + auto const row = static_cast(it - offsets_begin) - 1; + return list_mask_ptr == nullptr || cudf::bit_is_set(list_mask_ptr, row); + }, + stream, + mr); + + apply_parent_mask_to_row_aligned_column( + child, + static_cast(element_mask.data()), + element_null_count, + child_size, + stream, + mr); + propagate_nulls_to_descendants(child, stream, mr); +} + +void propagate_struct_nulls_to_descendants(cudf::column& struct_col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (struct_col.type().id() != cudf::type_id::STRUCT || struct_col.null_count() == 0) { return; } + + auto const struct_view = struct_col.view(); + auto const* struct_mask_ptr = struct_view.null_mask(); + auto const num_rows = struct_view.size(); + auto const null_count = struct_col.null_count(); + + for (cudf::size_type i = 0; i < struct_col.num_children(); ++i) { + auto& child = struct_col.child(i); + apply_parent_mask_to_row_aligned_column( + child, struct_mask_ptr, null_count, num_rows, stream, mr); + propagate_nulls_to_descendants(child, stream, mr); + } +} + +void propagate_nulls_to_descendants(cudf::column& col, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + switch (col.type().id()) { + case cudf::type_id::STRUCT: propagate_struct_nulls_to_descendants(col, stream, mr); break; + case cudf::type_id::LIST: propagate_list_nulls_to_descendants(col, stream, mr); break; + default: break; + } +} + std::unique_ptr make_null_column_with_schema( std::vector const& schema, int schema_idx, @@ -106,137 +215,102 @@ bool is_encoding_compatible(nested_field_descriptor const& field, cudf::data_typ void validate_decode_context(protobuf_decode_context 
const& context) { auto const num_fields = context.schema.size(); - if (context.default_ints.size() != num_fields) { - CUDF_FAIL("protobuf decode context: default_ints size mismatch with schema (" + - std::to_string(context.default_ints.size()) + " vs " + std::to_string(num_fields) + - ")", - std::invalid_argument); - } - if (context.default_floats.size() != num_fields) { - CUDF_FAIL("protobuf decode context: default_floats size mismatch with schema (" + - std::to_string(context.default_floats.size()) + " vs " + - std::to_string(num_fields) + ")", - std::invalid_argument); - } - if (context.default_bools.size() != num_fields) { - CUDF_FAIL("protobuf decode context: default_bools size mismatch with schema (" + - std::to_string(context.default_bools.size()) + " vs " + std::to_string(num_fields) + - ")", - std::invalid_argument); - } - if (context.default_strings.size() != num_fields) { - CUDF_FAIL("protobuf decode context: default_strings size mismatch with schema (" + - std::to_string(context.default_strings.size()) + " vs " + - std::to_string(num_fields) + ")", - std::invalid_argument); - } - if (context.enum_valid_values.size() != num_fields) { - CUDF_FAIL("protobuf decode context: enum_valid_values size mismatch with schema (" + - std::to_string(context.enum_valid_values.size()) + " vs " + - std::to_string(num_fields) + ")", - std::invalid_argument); - } - if (context.enum_names.size() != num_fields) { - CUDF_FAIL("protobuf decode context: enum_names size mismatch with schema (" + - std::to_string(context.enum_names.size()) + " vs " + std::to_string(num_fields) + - ")", - std::invalid_argument); - } + CUDF_EXPECTS(context.default_ints.size() == num_fields, + "protobuf decode context: default_ints size mismatch", + std::invalid_argument); + CUDF_EXPECTS(context.default_floats.size() == num_fields, + "protobuf decode context: default_floats size mismatch", + std::invalid_argument); + CUDF_EXPECTS(context.default_bools.size() == num_fields, + "protobuf decode 
context: default_bools size mismatch", + std::invalid_argument); + CUDF_EXPECTS(context.default_strings.size() == num_fields, + "protobuf decode context: default_strings size mismatch", + std::invalid_argument); + CUDF_EXPECTS(context.enum_valid_values.size() == num_fields, + "protobuf decode context: enum_valid_values size mismatch", + std::invalid_argument); + CUDF_EXPECTS(context.enum_names.size() == num_fields, + "protobuf decode context: enum_names size mismatch", + std::invalid_argument); - std::unordered_set seen_field_numbers; + std::set> seen_field_numbers; for (size_t i = 0; i < num_fields; ++i) { auto const& field = context.schema[i]; auto const type = cudf::data_type{field.output_type}; - if (field.field_number <= 0 || field.field_number > MAX_FIELD_NUMBER) { - CUDF_FAIL("protobuf decode context: invalid field number at field " + std::to_string(i), - std::invalid_argument); - } - if (field.depth < 0 || field.depth >= MAX_NESTING_DEPTH) { - CUDF_FAIL("protobuf decode context: field depth exceeds supported limit at field " + - std::to_string(i), - std::invalid_argument); - } - if (field.parent_idx < -1 || field.parent_idx >= static_cast(i)) { - CUDF_FAIL("protobuf decode context: invalid parent index at field " + std::to_string(i), - std::invalid_argument); - } - auto const key = (static_cast(static_cast(field.parent_idx)) << 32) | - static_cast(field.field_number); - if (!seen_field_numbers.insert(key).second) { - CUDF_FAIL("protobuf decode context: duplicate field number under same parent at field " + - std::to_string(i), - std::invalid_argument); - } + CUDF_EXPECTS(field.field_number > 0 && field.field_number <= MAX_FIELD_NUMBER, + "protobuf decode context: invalid field number at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(field.depth >= 0 && field.depth < MAX_NESTING_DEPTH, + "protobuf decode context: field depth exceeds limit at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(field.parent_idx >= -1 && 
field.parent_idx < static_cast(i), + "protobuf decode context: invalid parent index at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(seen_field_numbers.emplace(field.parent_idx, field.field_number).second, + "protobuf decode context: duplicate field number under same parent at field " + + std::to_string(i), + std::invalid_argument); if (field.parent_idx == -1) { - if (field.depth != 0) { - CUDF_FAIL("protobuf decode context: top-level field must have depth 0 at field " + - std::to_string(i), - std::invalid_argument); - } + CUDF_EXPECTS( + field.depth == 0, + "protobuf decode context: top-level field must have depth 0 at field " + std::to_string(i), + std::invalid_argument); } else { auto const& parent = context.schema[field.parent_idx]; - if (field.depth != parent.depth + 1) { - CUDF_FAIL("protobuf decode context: child depth mismatch at field " + std::to_string(i), - std::invalid_argument); - } - if (cudf::data_type{context.schema[field.parent_idx].output_type}.id() != - cudf::type_id::STRUCT) { - CUDF_FAIL("protobuf decode context: parent must be STRUCT at field " + std::to_string(i), - std::invalid_argument); - } + CUDF_EXPECTS(field.depth == parent.depth + 1, + "protobuf decode context: child depth mismatch at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(context.schema[field.parent_idx].output_type == cudf::type_id::STRUCT, + "protobuf decode context: parent must be STRUCT at field " + std::to_string(i), + std::invalid_argument); } - if (field.wire_type != proto_wire_type::VARINT && field.wire_type != proto_wire_type::I64BIT && - field.wire_type != proto_wire_type::LEN && field.wire_type != proto_wire_type::I32BIT) { - CUDF_FAIL("protobuf decode context: invalid wire type at field " + std::to_string(i), - std::invalid_argument); - } - if (field.encoding < proto_encoding::DEFAULT || field.encoding > proto_encoding::ENUM_STRING) { - CUDF_FAIL("protobuf decode context: invalid encoding at field " + 
std::to_string(i), - std::invalid_argument); - } - if (field.is_repeated && field.is_required) { - CUDF_FAIL("protobuf decode context: field cannot be both repeated and required at field " + - std::to_string(i), - std::invalid_argument); - } - if (field.is_repeated && field.has_default_value) { - CUDF_FAIL("protobuf decode context: repeated field cannot carry default value at field " + - std::to_string(i), - std::invalid_argument); - } - if (field.has_default_value && - (type.id() == cudf::type_id::STRUCT || type.id() == cudf::type_id::LIST)) { - CUDF_FAIL("protobuf decode context: STRUCT/LIST field cannot carry default value at field " + - std::to_string(i), - std::invalid_argument); - } - if (!is_encoding_compatible(field, type)) { - CUDF_FAIL("protobuf decode context: incompatible wire type/encoding/output type at field " + - std::to_string(i), - std::invalid_argument); - } + CUDF_EXPECTS( + field.wire_type == proto_wire_type::VARINT || field.wire_type == proto_wire_type::I64BIT || + field.wire_type == proto_wire_type::LEN || field.wire_type == proto_wire_type::I32BIT, + "protobuf decode context: invalid wire type at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS( + field.encoding >= proto_encoding::DEFAULT && field.encoding <= proto_encoding::ENUM_STRING, + "protobuf decode context: invalid encoding at field " + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(!(field.is_repeated && field.is_required), + "protobuf decode context: field cannot be both repeated and required at field " + + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(!(field.is_repeated && field.has_default_value), + "protobuf decode context: repeated field cannot carry default value at field " + + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(!(field.has_default_value && + (type.id() == cudf::type_id::STRUCT || type.id() == cudf::type_id::LIST)), + "protobuf decode context: STRUCT/LIST field cannot carry default value at field " + + 
std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS(is_encoding_compatible(field, type), + "protobuf decode context: incompatible wire type/encoding/output type at field " + + std::to_string(i), + std::invalid_argument); if (field.encoding == proto_encoding::ENUM_STRING) { - if (context.enum_valid_values[i].empty() || context.enum_names[i].empty()) { - CUDF_FAIL( - "protobuf decode context: enum-as-string field requires non-empty metadata at field " + - std::to_string(i), - std::invalid_argument); - } - if (context.enum_valid_values[i].size() != context.enum_names[i].size()) { - CUDF_FAIL( - "protobuf decode context: enum-as-string metadata mismatch at field " + std::to_string(i), - std::invalid_argument); - } + CUDF_EXPECTS( + !(context.enum_valid_values[i].empty() || context.enum_names[i].empty()), + "protobuf decode context: enum-as-string field requires non-empty metadata at field " + + std::to_string(i), + std::invalid_argument); + CUDF_EXPECTS( + context.enum_valid_values[i].size() == context.enum_names[i].size(), + "protobuf decode context: enum-as-string metadata mismatch at field " + std::to_string(i), + std::invalid_argument); auto const& ev = context.enum_valid_values[i]; for (size_t j = 1; j < ev.size(); ++j) { - if (ev[j] <= ev[j - 1]) { - CUDF_FAIL("protobuf decode context: enum_valid_values must be strictly sorted at field " + - std::to_string(i), - std::invalid_argument); - } + CUDF_EXPECTS( + ev[j] > ev[j - 1], + "protobuf decode context: enum_valid_values must be strictly sorted at field " + + std::to_string(i), + std::invalid_argument); } } } @@ -262,7 +336,14 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& rmm::device_async_resource_ref mr) { validate_decode_context(context); - auto const& schema = context.schema; + auto const& schema = context.schema; + auto const& default_ints = context.default_ints; + auto const& default_floats = context.default_floats; + auto const& default_bools = context.default_bools; + auto 
const& default_strings = context.default_strings; + auto const& enum_valid_values = context.enum_valid_values; + auto const& enum_names = context.enum_names; + bool fail_on_errors = context.fail_on_errors; CUDF_EXPECTS(binary_input.type().id() == cudf::type_id::LIST, "binary_input must be a LIST column"); cudf::lists_column_view const in_list(binary_input); @@ -307,8 +388,946 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); } + // Extract shared input data pointers (used by scalar, repeated, and nested sections) + cudf::lists_column_view const in_list_view(binary_input); + auto const* message_data = reinterpret_cast(in_list_view.child().data()); + auto const message_data_size = static_cast(in_list_view.child().size()); + auto const* list_offsets = in_list_view.offsets().data(); + + cudf::size_type base_offset = 0; + CUDF_CUDA_TRY(cudaMemcpyAsync( + &base_offset, list_offsets, sizeof(cudf::size_type), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + + // Copy schema to device + std::vector h_device_schema(num_fields); + for (int i = 0; i < num_fields; i++) { + h_device_schema[i] = device_nested_field_descriptor{schema[i]}; + } + + rmm::device_uvector d_schema(num_fields, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_schema.data(), + h_device_schema.data(), + num_fields * sizeof(device_nested_field_descriptor), + cudaMemcpyHostToDevice, + stream.value())); + + auto d_in = cudf::column_device_view::create(binary_input, stream); + // Identify repeated and nested fields at depth 0 + std::vector repeated_field_indices; + std::vector nested_field_indices; + std::vector scalar_field_indices; + + for (int i = 0; i < num_fields; i++) { + if (schema[i].parent_idx == -1) { // Top-level fields only + if (schema[i].is_repeated) { + repeated_field_indices.push_back(i); + } else if (schema[i].output_type == cudf::type_id::STRUCT) { + nested_field_indices.push_back(i); + } 
else { + scalar_field_indices.push_back(i); + } + } + } + + int num_repeated = static_cast(repeated_field_indices.size()); + int num_nested = static_cast(nested_field_indices.size()); + int num_scalar = static_cast(scalar_field_indices.size()); + + // Error flag + rmm::device_uvector d_error(1, stream, mr); + CUDF_CUDA_TRY(cudaMemsetAsync(d_error.data(), 0, sizeof(int), stream.value())); + auto error_message = [](int code) -> char const* { + switch (code) { + case ERR_BOUNDS: return "Protobuf decode error: message data out of bounds"; + case ERR_VARINT: return "Protobuf decode error: invalid or truncated varint"; + case ERR_FIELD_NUMBER: return "Protobuf decode error: invalid field number"; + case ERR_WIRE_TYPE: return "Protobuf decode error: unexpected wire type"; + case ERR_OVERFLOW: return "Protobuf decode error: length-delimited field overflows message"; + case ERR_FIELD_SIZE: return "Protobuf decode error: invalid field size"; + case ERR_SKIP: return "Protobuf decode error: unable to skip unknown field"; + case ERR_FIXED_LEN: + return "Protobuf decode error: invalid fixed-width or packed field length"; + case ERR_REQUIRED: return "Protobuf decode error: missing required field"; + case ERR_SCHEMA_TOO_LARGE: + return "Protobuf decode error: schema exceeds maximum supported repeated fields per kernel " + "(128)"; + case ERR_MISSING_ENUM_META: + return "Protobuf decode error: missing or mismatched enum metadata for enum-as-string " + "field"; + case ERR_REPEATED_COUNT_MISMATCH: + return "Protobuf decode error: repeated-field count/scan mismatch"; + default: return "Protobuf decode error: unknown error"; + } + }; + // PERMISSIVE-mode row nulling support. Unknown enum values and malformed rows should both + // surface as null structs instead of partially decoded data. 
+ bool has_enum_fields = std::any_of( + enum_valid_values.begin(), enum_valid_values.end(), [](auto const& v) { return !v.empty(); }); + bool track_permissive_null_rows = !fail_on_errors; + rmm::device_uvector d_row_force_null(track_permissive_null_rows ? num_rows : 0, stream, mr); + if (track_permissive_null_rows) { + CUDF_CUDA_TRY( + cudaMemsetAsync(d_row_force_null.data(), 0, num_rows * sizeof(bool), stream.value())); + } + + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((num_rows + threads - 1u) / threads); + + // Allocate for counting repeated fields + rmm::device_uvector d_repeated_info( + num_repeated > 0 ? static_cast(num_rows) * num_repeated : 1, stream, mr); + rmm::device_uvector d_nested_locations( + num_nested > 0 ? static_cast(num_rows) * num_nested : 1, stream, mr); + + rmm::device_uvector d_repeated_indices(num_repeated > 0 ? num_repeated : 1, stream, mr); + rmm::device_uvector d_nested_indices(num_nested > 0 ? num_nested : 1, stream, mr); + + if (num_repeated > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_repeated_indices.data(), + repeated_field_indices.data(), + num_repeated * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + } + if (num_nested > 0) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_nested_indices.data(), + nested_field_indices.data(), + num_nested * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + } + + // Count repeated fields at depth 0 (with O(1) field_number lookup tables) + rmm::device_uvector d_fn_to_rep(0, stream, mr); + rmm::device_uvector d_fn_to_nested(0, stream, mr); + + if (num_repeated > 0 || num_nested > 0) { + auto h_fn_to_rep = + build_index_lookup_table(schema.data(), repeated_field_indices.data(), num_repeated); + auto h_fn_to_nested = + build_index_lookup_table(schema.data(), nested_field_indices.data(), num_nested); + + if (!h_fn_to_rep.empty()) { + d_fn_to_rep = rmm::device_uvector(h_fn_to_rep.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_fn_to_rep.data(), + 
h_fn_to_rep.data(), + h_fn_to_rep.size() * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + } + if (!h_fn_to_nested.empty()) { + d_fn_to_nested = rmm::device_uvector(h_fn_to_nested.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_fn_to_nested.data(), + h_fn_to_nested.data(), + h_fn_to_nested.size() * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + } + + launch_count_repeated_fields(*d_in, + d_schema.data(), + num_fields, + 0, + d_repeated_info.data(), + num_repeated, + d_repeated_indices.data(), + d_nested_locations.data(), + num_nested, + d_nested_indices.data(), + d_error.data(), + d_fn_to_rep.data(), + static_cast(d_fn_to_rep.size()), + d_fn_to_nested.data(), + static_cast(d_fn_to_nested.size()), + num_rows, + stream); + } + + // Store decoded columns by schema index for ordered assembly at the end. std::vector> column_map(num_fields); + // Process scalar fields using existing infrastructure + if (num_scalar > 0) { + std::vector h_field_descs(num_scalar); + for (int i = 0; i < num_scalar; i++) { + int schema_idx = scalar_field_indices[i]; + h_field_descs[i].field_number = schema[schema_idx].field_number; + h_field_descs[i].expected_wire_type = static_cast(schema[schema_idx].wire_type); + h_field_descs[i].is_repeated = false; + } + + rmm::device_uvector d_field_descs(num_scalar, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_descs.data(), + h_field_descs.data(), + num_scalar * sizeof(field_descriptor), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector d_locations( + static_cast(num_rows) * num_scalar, stream, mr); + + auto h_field_lookup = build_field_lookup_table(h_field_descs.data(), num_scalar); + rmm::device_uvector d_field_lookup(h_field_lookup.size(), stream, mr); + if (!h_field_lookup.empty()) { + CUDF_CUDA_TRY(cudaMemcpyAsync(d_field_lookup.data(), + h_field_lookup.data(), + h_field_lookup.size() * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + } + + launch_scan_all_fields(*d_in, + 
d_field_descs.data(), + num_scalar, + h_field_lookup.empty() ? nullptr : d_field_lookup.data(), + static_cast(h_field_lookup.size()), + d_locations.data(), + d_error.data(), + track_permissive_null_rows ? d_row_force_null.data() : nullptr, + num_rows, + stream); + + // Required-field validation applies to all scalar leaves, not just top-level numerics. + maybe_check_required_fields(d_locations.data(), + scalar_field_indices, + schema, + num_rows, + binary_input.null_count() > 0 ? binary_input.null_mask() : nullptr, + binary_input.offset(), + nullptr, + track_permissive_null_rows ? d_row_force_null.data() : nullptr, + nullptr, + d_error.data(), + stream); + + // Batched scalar extraction: group non-special fixed-width fields by extraction + // category and extract all fields of each category with a single 2D kernel launch. + { + struct scalar_buf_pair { + rmm::device_uvector out_bytes; + rmm::device_uvector valid; + scalar_buf_pair(rmm::cuda_stream_view s, rmm::device_async_resource_ref m) + : out_bytes(0, s, m), valid(0, s, m) + { + } + }; + + // Classify each scalar field + // 0=I32, 1=U32, 2=I64, 3=U64, 4=BOOL, 5=I32zz, 6=I64zz, 7=F32, 8=F64, + // 9=I32fixed, 10=I64fixed, 11=fallback + constexpr int NUM_GROUPS = 12; + constexpr int GRP_FALLBACK = 11; + std::vector group_lists[NUM_GROUPS]; + + for (int i = 0; i < num_scalar; i++) { + int si = scalar_field_indices[i]; + auto tid = cudf::data_type{schema[si].output_type}.id(); + auto enc = schema[si].encoding; + bool zz = (enc == proto_encoding::ZIGZAG); + + // STRING, LIST, and enum-as-string go to per-field path + if (tid == cudf::type_id::STRING || tid == cudf::type_id::LIST) continue; + + bool is_fixed = (enc == proto_encoding::FIXED); + + // INT32 with enum validation goes to fallback + if (tid == cudf::type_id::INT32 && !zz && !is_fixed && + si < static_cast(enum_valid_values.size()) && !enum_valid_values[si].empty()) { + group_lists[GRP_FALLBACK].push_back(i); + continue; + } + + int g = GRP_FALLBACK; + if 
(tid == cudf::type_id::INT32 && is_fixed) { + g = 9; + } else if (tid == cudf::type_id::INT64 && is_fixed) { + g = 10; + } else if (tid == cudf::type_id::UINT32 && is_fixed) { + g = 9; + } else if (tid == cudf::type_id::UINT64 && is_fixed) { + g = 10; + } else if (tid == cudf::type_id::INT32 && !zz) { + g = 0; + } else if (tid == cudf::type_id::UINT32) { + g = 1; + } else if (tid == cudf::type_id::INT64 && !zz) { + g = 2; + } else if (tid == cudf::type_id::UINT64) { + g = 3; + } else if (tid == cudf::type_id::BOOL8) { + g = 4; + } else if (tid == cudf::type_id::INT32 && zz) { + g = 5; + } else if (tid == cudf::type_id::INT64 && zz) { + g = 6; + } else if (tid == cudf::type_id::FLOAT32) { + g = 7; + } else if (tid == cudf::type_id::FLOAT64) { + g = 8; + } + group_lists[g].push_back(i); + } + + // Helper: batch-extract one group using a 2D kernel, then build columns. + auto do_batch = [&](std::vector const& idxs, auto kernel_launcher) { + int nf = static_cast(idxs.size()); + if (nf == 0) return; + + std::vector> bufs; + bufs.reserve(nf); + std::vector h_descs(nf); + + for (int j = 0; j < nf; j++) { + int li = idxs[j]; + int si = scalar_field_indices[li]; + bool hd = schema[si].has_default_value; + auto& bp = *bufs.emplace_back(std::make_unique(stream, mr)); + bp.valid = rmm::device_uvector(std::max(1, num_rows), stream, mr); + // BOOL8 default comes from default_bools (converted to 0/1 int) + bool is_bool = (cudf::data_type{schema[si].output_type}.id() == cudf::type_id::BOOL8); + int64_t def_i = hd ? (is_bool ? (default_bools[si] ? 1 : 0) : default_ints[si]) : 0; + h_descs[j] = {li, nullptr, bp.valid.data(), hd, def_i, hd ? 
default_floats[si] : 0.0}; + } + + // kernel_launcher allocates out_bytes, sets h_descs[j].output, and launches kernel + kernel_launcher(nf, h_descs, bufs); + + // Build columns + for (int j = 0; j < nf; j++) { + int si = scalar_field_indices[idxs[j]]; + auto dt = cudf::data_type{schema[si].output_type}; + auto& bp = *bufs[j]; + auto [mask, null_count] = make_null_mask_from_valid(bp.valid, stream, mr); + column_map[si] = std::make_unique( + dt, num_rows, bp.out_bytes.release(), std::move(mask), null_count); + } + }; + + // Varint launcher for type T with zigzag ZZ + auto varint_launch = [&](int nf, + std::vector& h_descs, + std::vector>& bufs, + size_t elem_size, + auto kernel_fn) { + for (int j = 0; j < nf; j++) { + bufs[j]->out_bytes = rmm::device_uvector(num_rows * elem_size, stream, mr); + h_descs[j].output = bufs[j]->out_bytes.data(); + } + rmm::device_uvector d_descs(nf, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_descs.data(), + h_descs.data(), + nf * sizeof(h_descs[0]), + cudaMemcpyHostToDevice, + stream.value())); + dim3 grid((num_rows + threads - 1u) / threads, nf); + kernel_fn(grid, + threads, + stream.value(), + message_data, + list_offsets, + base_offset, + d_locations.data(), + num_scalar, + d_descs.data(), + nf, + num_rows, + d_error.data()); + }; + +// Dispatch groups 0-8 as batched +#define LAUNCH_VARINT_BATCH(GROUP, TYPE, ZZ) \ + do_batch(group_lists[GROUP], [&](int nf, auto& hd, auto& bf) { \ + varint_launch(nf, hd, bf, sizeof(TYPE), [](dim3 g, int t, cudaStream_t s, auto... args) { \ + extract_varint_batched_kernel<<>>(args...); \ + }); \ + }) + +#define LAUNCH_FIXED_BATCH(GROUP, TYPE, WT_VAL) \ + do_batch(group_lists[GROUP], [&](int nf, auto& hd, auto& bf) { \ + varint_launch(nf, hd, bf, sizeof(TYPE), [](dim3 g, int t, cudaStream_t s, auto... 
args) { \ + extract_fixed_batched_kernel<<>>(args...); \ + }); \ + }) + + LAUNCH_VARINT_BATCH(0, int32_t, false); + LAUNCH_VARINT_BATCH(1, uint32_t, false); + LAUNCH_VARINT_BATCH(2, int64_t, false); + LAUNCH_VARINT_BATCH(3, uint64_t, false); + LAUNCH_VARINT_BATCH(4, uint8_t, false); + LAUNCH_VARINT_BATCH(5, int32_t, true); + LAUNCH_VARINT_BATCH(6, int64_t, true); + LAUNCH_FIXED_BATCH(7, float, wire_type_value(proto_wire_type::I32BIT)); + LAUNCH_FIXED_BATCH(8, double, wire_type_value(proto_wire_type::I64BIT)); + LAUNCH_FIXED_BATCH(9, int32_t, wire_type_value(proto_wire_type::I32BIT)); + LAUNCH_FIXED_BATCH(10, int64_t, wire_type_value(proto_wire_type::I64BIT)); + +#undef LAUNCH_VARINT_BATCH +#undef LAUNCH_FIXED_BATCH + + // Per-field fallback (INT32 with enum, etc.) + for (int i : group_lists[GRP_FALLBACK]) { + int schema_idx = scalar_field_indices[i]; + auto const field_meta = make_field_meta_view(context, schema_idx); + auto const dt = field_meta.output_type; + auto const enc = static_cast(field_meta.schema.encoding); + bool has_def = field_meta.schema.has_default_value; + top_level_location_provider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + column_map[schema_idx] = extract_typed_column(dt, + enc, + message_data, + loc_provider, + num_rows, + blocks, + threads, + has_def, + has_def ? field_meta.default_int : 0, + has_def ? field_meta.default_float : 0.0, + has_def ? 
field_meta.default_bool : false, + field_meta.default_string, + schema_idx, + enum_valid_values, + enum_names, + d_row_force_null, + d_error, + stream, + mr); + } + } + + // Per-field extraction for STRING and LIST types + for (int i = 0; i < num_scalar; i++) { + int schema_idx = scalar_field_indices[i]; + auto const field_meta = make_field_meta_view(context, schema_idx); + auto const dt = field_meta.output_type; + if (dt.id() != cudf::type_id::STRING && dt.id() != cudf::type_id::LIST) { continue; } + auto const enc = field_meta.schema.encoding; + bool has_def = field_meta.schema.has_default_value; + + switch (dt.id()) { + case cudf::type_id::STRING: { + if (enc == proto_encoding::ENUM_STRING) { + // ENUM-as-string path: + // 1. Decode enum numeric value as INT32 varint. + // 2. Validate against enum_valid_values. + // 3. Convert INT32 -> UTF-8 enum name bytes. + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); + int64_t def_int = has_def ? field_meta.default_int : 0; + top_level_location_provider loc_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out.data(), + valid.data(), + d_error.data(), + has_def, + def_int); + + if (schema_idx < static_cast(enum_valid_values.size()) && + schema_idx < static_cast(enum_names.size())) { + auto const& valid_enums = enum_valid_values[schema_idx]; + auto const& enum_name_bytes = enum_names[schema_idx]; + if (!valid_enums.empty() && valid_enums.size() == enum_name_bytes.size()) { + column_map[schema_idx] = build_enum_string_column( + out, valid, valid_enums, enum_name_bytes, d_row_force_null, num_rows, stream, mr); + } else { + // Missing enum metadata for enum-as-string field; mark as decode error. 
+ set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); + column_map[schema_idx] = make_null_column(dt, num_rows, stream, mr); + } + } else { + set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); + column_map[schema_idx] = make_null_column(dt, num_rows, stream, mr); + } + } else { + // Regular protobuf STRING (length-delimited) + bool has_def_str = has_def; + auto const& def_str = field_meta.default_string; + top_level_location_provider len_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + top_level_location_provider copy_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + auto valid_fn = [locs = d_locations.data(), i, num_scalar, has_def_str] __device__( + cudf::size_type row) { + return locs[flat_index(static_cast(row), + static_cast(num_scalar), + static_cast(i))] + .offset >= 0 || + has_def_str; + }; + column_map[schema_idx] = extract_and_build_string_or_bytes_column(false, + message_data, + num_rows, + len_provider, + copy_provider, + valid_fn, + has_def_str, + def_str, + d_error, + stream, + mr); + } + break; + } + case cudf::type_id::LIST: { + // bytes (BinaryType) represented as LIST + bool has_def_bytes = has_def; + auto const& def_bytes = field_meta.default_string; + top_level_location_provider len_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + top_level_location_provider copy_provider{ + list_offsets, base_offset, d_locations.data(), i, num_scalar}; + auto valid_fn = [locs = d_locations.data(), i, num_scalar, has_def_bytes] __device__( + cudf::size_type row) { + return locs[flat_index(static_cast(row), + static_cast(num_scalar), + static_cast(i))] + .offset >= 0 || + has_def_bytes; + }; + column_map[schema_idx] = extract_and_build_string_or_bytes_column(true, + message_data, + num_rows, + len_provider, + copy_provider, + valid_fn, + has_def_bytes, + def_bytes, + d_error, + stream, + mr); + break; + } + default: + // For LIST (bytes) and other 
unsupported types, create placeholder columns + column_map[schema_idx] = make_null_column(dt, num_rows, stream, mr); + break; + } + } + } + + // Required top-level nested messages are tracked in d_nested_locations during the scan/count + // pass. + maybe_check_required_fields(d_nested_locations.data(), + nested_field_indices, + schema, + num_rows, + binary_input.null_count() > 0 ? binary_input.null_mask() : nullptr, + binary_input.offset(), + nullptr, + track_permissive_null_rows ? d_row_force_null.data() : nullptr, + nullptr, + d_error.data(), + stream); + + // Process repeated fields (three-phase: offsets → combined scan → build columns) + if (num_repeated > 0) { + // Phase A: Compute per-row offsets for each repeated field. + struct repeated_field_work { + int schema_idx; + int32_t total_count{0}; + rmm::device_uvector counts; + rmm::device_uvector offsets; + std::unique_ptr> occurrences; + + repeated_field_work(int si, + cudf::size_type n, + rmm::cuda_stream_view s, + rmm::device_async_resource_ref m) + : schema_idx(si), counts(n, s, m), offsets(n + 1, s, m) + { + } + }; + + std::vector> rep_work; + rep_work.reserve(num_repeated); + + for (int ri = 0; ri < num_repeated; ri++) { + int schema_idx = repeated_field_indices[ri]; + auto& w = *rep_work.emplace_back( + std::make_unique(schema_idx, num_rows, stream, mr)); + + thrust::transform(rmm::exec_policy_nosync(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + w.counts.data(), + extract_strided_count{d_repeated_info.data(), ri, num_repeated}); + + CUDF_CUDA_TRY(cudaMemsetAsync(w.offsets.data(), 0, sizeof(int32_t), stream.value())); + thrust::inclusive_scan( + rmm::exec_policy_nosync(stream), w.counts.begin(), w.counts.end(), w.offsets.data() + 1); + + CUDF_CUDA_TRY(cudaMemcpyAsync(&w.total_count, + w.offsets.data() + num_rows, + sizeof(int32_t), + cudaMemcpyDeviceToHost, + stream.value())); + } + stream.synchronize(); + + // Phase B: Allocate occurrence buffers and launch 
ONE combined scan kernel. + std::vector h_scan_descs; + h_scan_descs.reserve(num_repeated); + + for (auto& wp : rep_work) { + if (wp->total_count > 0) { + wp->occurrences = + std::make_unique>(wp->total_count, stream, mr); + h_scan_descs.push_back({schema[wp->schema_idx].field_number, + static_cast(schema[wp->schema_idx].wire_type), + wp->offsets.data(), + wp->occurrences->data()}); + } + } + + if (!h_scan_descs.empty()) { + rmm::device_uvector d_scan_descs(h_scan_descs.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_scan_descs.data(), + h_scan_descs.data(), + h_scan_descs.size() * sizeof(h_scan_descs[0]), + cudaMemcpyHostToDevice, + stream.value())); + + // Build field_number -> scan_desc_index lookup for the combined kernel + int max_scan_fn = 0; + for (auto const& sd : h_scan_descs) { + max_scan_fn = std::max(max_scan_fn, sd.field_number); + } + rmm::device_uvector d_fn_to_scan(0, stream, mr); + int fn_to_scan_size = 0; + if (max_scan_fn <= FIELD_LOOKUP_TABLE_MAX) { + std::vector h_fn_to_scan(max_scan_fn + 1, -1); + for (int i = 0; i < static_cast(h_scan_descs.size()); i++) { + h_fn_to_scan[h_scan_descs[i].field_number] = i; + } + d_fn_to_scan = rmm::device_uvector(h_fn_to_scan.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_fn_to_scan.data(), + h_fn_to_scan.data(), + h_fn_to_scan.size() * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + fn_to_scan_size = static_cast(h_fn_to_scan.size()); + } + + launch_scan_all_repeated_occurrences(*d_in, + d_scan_descs.data(), + static_cast(h_scan_descs.size()), + d_error.data(), + fn_to_scan_size > 0 ? d_fn_to_scan.data() : nullptr, + fn_to_scan_size, + num_rows, + stream); + } + + // Phase C: Build columns per field. 
+ for (int ri = 0; ri < num_repeated; ri++) { + auto& w = *rep_work[ri]; + int schema_idx = w.schema_idx; + auto element_type = cudf::data_type{schema[schema_idx].output_type}; + int32_t total_count = w.total_count; + auto& d_field_counts = w.counts; + + if (total_count > 0) { + auto& d_occurrences = *w.occurrences; + + // Build the appropriate column type based on element type + auto child_type_id = static_cast(h_device_schema[schema_idx].output_type_id); + + // The output_type in schema is the LIST type, but we need element type + // For repeated int32, output_type should indicate the element is INT32 + switch (child_type_id) { + case cudf::type_id::INT32: + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, + h_device_schema[schema_idx], + d_field_counts, + d_occurrences, + total_count, + num_rows, + d_error, + stream, + mr); + break; + case cudf::type_id::INT64: + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, + h_device_schema[schema_idx], + d_field_counts, + d_occurrences, + total_count, + num_rows, + d_error, + stream, + mr); + break; + case cudf::type_id::UINT32: + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, + h_device_schema[schema_idx], + d_field_counts, + d_occurrences, + total_count, + num_rows, + d_error, + stream, + mr); + break; + case cudf::type_id::UINT64: + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, + h_device_schema[schema_idx], + d_field_counts, + d_occurrences, + total_count, + num_rows, + d_error, + stream, + mr); + break; + case cudf::type_id::FLOAT32: + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, + h_device_schema[schema_idx], + d_field_counts, + d_occurrences, + total_count, + num_rows, + 
d_error, + stream, + mr); + break; + case cudf::type_id::FLOAT64: + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, + h_device_schema[schema_idx], + d_field_counts, + d_occurrences, + total_count, + num_rows, + d_error, + stream, + mr); + break; + case cudf::type_id::BOOL8: + column_map[schema_idx] = + build_repeated_scalar_column(binary_input, + message_data, + list_offsets, + base_offset, + h_device_schema[schema_idx], + d_field_counts, + d_occurrences, + total_count, + num_rows, + d_error, + stream, + mr); + break; + case cudf::type_id::STRING: { + auto const field_meta = make_field_meta_view(context, schema_idx); + auto enc = field_meta.schema.encoding; + if (enc == proto_encoding::ENUM_STRING) { + if (!field_meta.enum_valid_values.empty() && + field_meta.enum_valid_values.size() == field_meta.enum_names.size()) { + column_map[schema_idx] = + build_repeated_enum_string_column(binary_input, + message_data, + list_offsets, + base_offset, + d_field_counts, + d_occurrences, + total_count, + num_rows, + field_meta.enum_valid_values, + field_meta.enum_names, + d_row_force_null, + d_error, + stream, + mr); + } else { + set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); + column_map[schema_idx] = make_null_column( + cudf::data_type{schema[schema_idx].output_type}, num_rows, stream, mr); + } + } else { + column_map[schema_idx] = build_repeated_string_column(binary_input, + message_data, + list_offsets, + base_offset, + h_device_schema[schema_idx], + d_field_counts, + d_occurrences, + total_count, + num_rows, + false, + d_error, + stream, + mr); + } + break; + } + case cudf::type_id::LIST: // bytes as LIST + column_map[schema_idx] = build_repeated_string_column(binary_input, + message_data, + list_offsets, + base_offset, + h_device_schema[schema_idx], + d_field_counts, + d_occurrences, + total_count, + num_rows, + true, + d_error, + stream, + mr); + break; + case cudf::type_id::STRUCT: 
{ + // Repeated message field - ArrayType(StructType) + auto child_field_indices = find_child_field_indices(schema, num_fields, schema_idx); + if (child_field_indices.empty()) { + auto empty_struct_child = + make_empty_struct_column_with_schema(schema, schema_idx, num_fields, stream, mr); + column_map[schema_idx] = make_null_list_column_with_child( + std::move(empty_struct_child), num_rows, stream, mr); + } else { + column_map[schema_idx] = build_repeated_struct_column(binary_input, + message_data, + message_data_size, + list_offsets, + base_offset, + h_device_schema[schema_idx], + d_field_counts, + d_occurrences, + total_count, + num_rows, + h_device_schema, + child_field_indices, + default_ints, + default_floats, + default_bools, + default_strings, + schema, + enum_valid_values, + enum_names, + d_row_force_null, + d_error, + stream, + mr); + } + break; + } + default: + // Unsupported element type - create null column + column_map[schema_idx] = make_null_list_column_with_child( + make_empty_column_safe(element_type, stream, mr), num_rows, stream, mr); + break; + } + } else { + // All rows have count=0 - create list of empty elements + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + + // Build appropriate empty child column + std::unique_ptr child_col; + auto child_type_id = static_cast(h_device_schema[schema_idx].output_type_id); + if (child_type_id == cudf::type_id::STRUCT) { + // Use helper to build empty struct with proper nested structure + child_col = + make_empty_struct_column_with_schema(schema, schema_idx, num_fields, stream, mr); + } else { + child_col = + make_empty_column_safe(cudf::data_type{schema[schema_idx].output_type}, stream, mr); + } + + auto const input_null_count = binary_input.null_count(); + if 
(input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + column_map[schema_idx] = cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask)); + } else { + column_map[schema_idx] = cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); + } + } + } + } + + // Process nested struct fields (Phase 2) + if (num_nested > 0) { + for (int ni = 0; ni < num_nested; ni++) { + int parent_schema_idx = nested_field_indices[ni]; + + // Find child fields of this nested message + auto child_field_indices = find_child_field_indices(schema, num_fields, parent_schema_idx); + + if (child_field_indices.empty()) { + // No child fields - create empty struct + column_map[parent_schema_idx] = make_null_column( + cudf::data_type{schema[parent_schema_idx].output_type}, num_rows, stream, mr); + continue; + } + + // Extract parent locations for this nested field directly on GPU + rmm::device_uvector d_parent_locs(num_rows, stream, mr); + launch_extract_strided_locations( + d_nested_locations.data(), ni, num_nested, d_parent_locs.data(), num_rows, stream); + + column_map[parent_schema_idx] = build_nested_struct_column(message_data, + message_data_size, + list_offsets, + base_offset, + d_parent_locs, + child_field_indices, + schema, + num_fields, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_force_null, + d_error, + num_rows, + stream, + mr, + nullptr, + 0, + false); + } + } + + // Assemble top_level_children in schema order (not processing order) std::vector> top_level_children; for (int i = 0; i < num_fields; i++) { if (schema[i].parent_idx == -1) { @@ -321,15 +1340,47 @@ std::unique_ptr decode_protobuf_to_struct(cudf::column_view const& } } - auto const input_null_count = binary_input.null_count(); - if (input_null_count > 0) { - auto null_mask = 
cudf::copy_bitmask(binary_input, stream, mr); - return cudf::make_structs_column( - num_rows, std::move(top_level_children), input_null_count, std::move(null_mask), stream, mr); + CUDF_CUDA_TRY(cudaPeekAtLastError()); + int h_error = 0; + CUDF_CUDA_TRY( + cudaMemcpyAsync(&h_error, d_error.data(), sizeof(int), cudaMemcpyDeviceToHost, stream.value())); + stream.synchronize(); + if (h_error == ERR_SCHEMA_TOO_LARGE || h_error == ERR_REPEATED_COUNT_MISMATCH) { + throw cudf::logic_error(error_message(h_error)); + } + if (fail_on_errors && h_error != 0) throw cudf::logic_error(error_message(h_error)); + + // Build final struct with PERMISSIVE mode null mask for invalid enums + cudf::size_type struct_null_count = 0; + rmm::device_buffer struct_mask{0, stream, mr}; + + if (track_permissive_null_rows) { + auto [mask, null_count] = cudf::detail::valid_if( + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + [row_invalid = d_row_force_null.data()] __device__(cudf::size_type row) { + return !row_invalid[row]; + }, + stream, + mr); + struct_mask = std::move(mask); + struct_null_count = null_count; + } + + // cuDF child views do not automatically inherit parent nulls. Push PERMISSIVE invalid-enum + // nulls down into every top-level child, then recursively through nested STRUCT/LIST children, + // so callers that access backing grandchildren directly still observe logically-null rows. 
+ if (track_permissive_null_rows && struct_null_count > 0) { + auto const* struct_mask_ptr = static_cast(struct_mask.data()); + for (auto& child : top_level_children) { + apply_parent_mask_to_row_aligned_column( + *child, struct_mask_ptr, struct_null_count, num_rows, stream, mr); + propagate_nulls_to_descendants(*child, stream, mr); + } } return cudf::make_structs_column( - num_rows, std::move(top_level_children), 0, rmm::device_buffer{}, stream, mr); + num_rows, std::move(top_level_children), struct_null_count, std::move(struct_mask), stream, mr); } } // namespace detail diff --git a/src/main/cpp/src/protobuf/protobuf_builders.cu b/src/main/cpp/src/protobuf/protobuf_builders.cu index 42420acedf..e4b149db2f 100644 --- a/src/main/cpp/src/protobuf/protobuf_builders.cu +++ b/src/main/cpp/src/protobuf/protobuf_builders.cu @@ -21,6 +21,115 @@ namespace spark_rapids_jni::protobuf::detail { +std::unique_ptr build_repeated_msg_child_varlen_column( + uint8_t const* message_data, + rmm::device_uvector const& d_msg_row_offsets, + rmm::device_uvector const& d_msg_locs, + rmm::device_uvector const& d_child_locs, + int child_idx, + int num_child_fields, + int total_count, + rmm::device_uvector& d_error, + bool as_bytes, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + if (total_count == 0) { + if (as_bytes) return make_empty_column_safe(cudf::data_type{cudf::type_id::LIST}, stream, mr); + return cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); + } + + rmm::device_uvector d_lengths(total_count, stream, mr); + thrust::transform( + rmm::exec_policy_nosync(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + d_lengths.data(), + [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { + auto const& loc = child_locs[flat_index( + static_cast(idx), static_cast(ncf), static_cast(ci))]; + return loc.offset >= 0 ? 
loc.length : 0; + }); + + auto [offsets_col, total_data] = cudf::strings::detail::make_offsets_child_column( + d_lengths.begin(), d_lengths.end(), stream, mr); + + rmm::device_uvector d_data(total_data, stream, mr); + rmm::device_uvector d_valid((total_count > 0 ? total_count : 1), stream, mr); + + thrust::transform( + rmm::exec_policy_nosync(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(total_count), + d_valid.data(), + [child_locs = d_child_locs.data(), ci = child_idx, ncf = num_child_fields] __device__(int idx) { + return child_locs[flat_index(static_cast(idx), + static_cast(ncf), + static_cast(ci))] + .offset >= 0; + }); + + if (total_data > 0) { + repeated_msg_child_location_provider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + child_idx, + num_child_fields}; + auto const* offsets_data = offsets_col->view().data(); + auto* chars_ptr = d_data.data(); + + auto src_iter = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type( + [message_data, loc_provider] __device__(int idx) -> void const* { + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + if (loc.offset < 0) return nullptr; + return static_cast(message_data + data_offset); + })); + auto dst_iter = cudf::detail::make_counting_transform_iterator( + 0, cuda::proclaim_return_type([chars_ptr, offsets_data] __device__(int idx) -> void* { + return static_cast(chars_ptr + offsets_data[idx]); + })); + auto size_iter = cudf::detail::make_counting_transform_iterator( + 0, cuda::proclaim_return_type([loc_provider] __device__(int idx) -> size_t { + int32_t data_offset = 0; + auto loc = loc_provider.get(idx, data_offset); + if (loc.offset < 0) return 0; + return static_cast(loc.length); + })); + + size_t temp_storage_bytes = 0; + cub::DeviceMemcpy::Batched( + nullptr, temp_storage_bytes, src_iter, dst_iter, size_iter, total_count, stream.value()); + rmm::device_buffer 
temp_storage(temp_storage_bytes, stream, mr); + cub::DeviceMemcpy::Batched(temp_storage.data(), + temp_storage_bytes, + src_iter, + dst_iter, + size_iter, + total_count, + stream.value()); + } + + auto [mask, null_count] = make_null_mask_from_valid(d_valid, stream, mr); + + if (as_bytes) { + auto bytes_child = + std::make_unique(cudf::data_type{cudf::type_id::UINT8}, + total_data, + rmm::device_buffer(d_data.data(), total_data, stream, mr), + rmm::device_buffer{}, + 0); + return cudf::make_lists_column( + total_count, std::move(offsets_col), std::move(bytes_child), null_count, std::move(mask)); + } + + return cudf::make_strings_column( + total_count, std::move(offsets_col), d_data.release(), null_count, std::move(mask)); +} + std::unique_ptr make_null_column(cudf::data_type dtype, cudf::size_type num_rows, rmm::cuda_stream_view stream, @@ -123,4 +232,1477 @@ std::unique_ptr make_empty_list_column(std::unique_ptr d_valid_enums; + rmm::device_uvector d_name_offsets; + rmm::device_uvector d_name_chars; +}; + +enum_string_lookup_tables make_enum_string_lookup_tables( + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto d_valid_enums = cudf::detail::make_device_uvector_async( + valid_enums, stream, cudf::get_current_device_resource_ref()); + + auto h_name_offsets = + cudf::detail::make_pinned_vector_async(valid_enums.size() + 1, stream); + std::fill(h_name_offsets.begin(), h_name_offsets.end(), 0); + int64_t total_name_chars = 0; + for (size_t k = 0; k < enum_name_bytes.size(); ++k) { + total_name_chars += static_cast(enum_name_bytes[k].size()); + CUDF_EXPECTS(total_name_chars <= std::numeric_limits::max(), + "Enum name data exceeds 2 GB limit"); + h_name_offsets[k + 1] = static_cast(total_name_chars); + } + + auto h_name_chars = cudf::detail::make_pinned_vector_async(total_name_chars, stream); + int32_t cursor = 0; + for (auto const& name : enum_name_bytes) 
{ + if (!name.empty()) { + std::copy(name.data(), name.data() + name.size(), h_name_chars.data() + cursor); + cursor += static_cast(name.size()); + } + } + + auto d_name_offsets = cudf::detail::make_device_uvector_async( + h_name_offsets, stream, cudf::get_current_device_resource_ref()); + + auto d_name_chars = [&]() { + if (total_name_chars > 0) { + return cudf::detail::make_device_uvector_async( + h_name_chars, stream, cudf::get_current_device_resource_ref()); + } + return rmm::device_uvector(0, stream, mr); + }(); + + return {std::move(d_valid_enums), std::move(d_name_offsets), std::move(d_name_chars)}; +} + +std::unique_ptr build_enum_string_values_column( + rmm::device_uvector& enum_values, + rmm::device_uvector& valid, + enum_string_lookup_tables const& lookup, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + rmm::device_uvector lengths(num_rows, stream, mr); + launch_compute_enum_string_lengths(enum_values.data(), + valid.data(), + lookup.d_valid_enums.data(), + lookup.d_name_offsets.data(), + static_cast(lookup.d_valid_enums.size()), + lengths.data(), + num_rows, + stream); + + auto [offsets_col, total_chars] = + cudf::strings::detail::make_offsets_child_column(lengths.begin(), lengths.end(), stream, mr); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + launch_copy_enum_string_chars(enum_values.data(), + valid.data(), + lookup.d_valid_enums.data(), + lookup.d_name_offsets.data(), + lookup.d_name_chars.data(), + static_cast(lookup.d_valid_enums.size()), + offsets_col->view().data(), + chars.data(), + num_rows, + stream); + } + + auto [mask, null_count] = make_null_mask_from_valid(valid, stream, mr); + return cudf::make_strings_column( + num_rows, std::move(offsets_col), chars.release(), null_count, std::move(mask)); +} + +std::unique_ptr build_enum_string_column( + rmm::device_uvector& enum_values, + rmm::device_uvector& valid, + cudf::detail::host_vector const& valid_enums, + 
std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_force_null, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, + bool propagate_invalid_rows) +{ + auto lookup = make_enum_string_lookup_tables(valid_enums, enum_name_bytes, stream, mr); + rmm::device_uvector d_item_has_invalid_enum(num_rows, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), + d_item_has_invalid_enum.begin(), + d_item_has_invalid_enum.end(), + false); + + launch_validate_enum_values(enum_values.data(), + valid.data(), + d_item_has_invalid_enum.data(), + lookup.d_valid_enums.data(), + static_cast(valid_enums.size()), + num_rows, + stream); + propagate_invalid_enum_flags_to_rows(d_item_has_invalid_enum, + d_row_force_null, + num_rows, + top_row_indices, + propagate_invalid_rows, + stream); + return build_enum_string_values_column(enum_values, valid, lookup, num_rows, stream, mr); +} + +std::unique_ptr build_repeated_msg_child_enum_string_column( + uint8_t const* message_data, + rmm::device_uvector const& d_msg_row_offsets, + rmm::device_uvector const& d_msg_locs, + rmm::device_uvector const& d_child_locs, + int child_idx, + int num_child_fields, + int total_count, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_force_null, + int32_t const* top_row_indices, + bool propagate_invalid_rows, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((total_count + threads - 1u) / threads); + auto lookup = make_enum_string_lookup_tables(valid_enums, enum_name_bytes, stream, mr); + + rmm::device_uvector enum_values(total_count, stream, mr); + rmm::device_uvector valid((total_count > 0 ? 
total_count : 1), stream, mr); + repeated_msg_child_location_provider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + child_idx, + num_child_fields}; + extract_varint_kernel + <<>>(message_data, + loc_provider, + total_count, + enum_values.data(), + valid.data(), + d_error.data(), + false, + 0); + + rmm::device_uvector d_elem_has_invalid_enum(total_count, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), + d_elem_has_invalid_enum.begin(), + d_elem_has_invalid_enum.end(), + false); + launch_validate_enum_values(enum_values.data(), + valid.data(), + d_elem_has_invalid_enum.data(), + lookup.d_valid_enums.data(), + static_cast(valid_enums.size()), + total_count, + stream); + propagate_invalid_enum_flags_to_rows(d_elem_has_invalid_enum, + d_row_force_null, + total_count, + top_row_indices, + propagate_invalid_rows, + stream); + return build_enum_string_values_column(enum_values, valid, lookup, total_count, stream, mr); +} + +std::unique_ptr build_repeated_enum_string_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + cudf::detail::host_vector const& valid_enums, + std::vector> const& enum_name_bytes, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const rep_blocks = + static_cast((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + auto lookup = make_enum_string_lookup_tables(valid_enums, enum_name_bytes, stream, mr); + + // 1. 
Extract enum integer values from occurrences + rmm::device_uvector enum_ints(total_count, stream, mr); + rmm::device_uvector elem_valid(total_count, stream, mr); + repeated_location_provider rep_loc{list_offsets, base_offset, d_occurrences.data()}; + extract_varint_kernel + <<>>(message_data, + rep_loc, + total_count, + enum_ints.data(), + elem_valid.data(), + d_error.data(), + false, + 0); + + // 2. Validate enum values — mark invalid as false in elem_valid + // (elem_valid was already populated by extract_varint_kernel: true for success, false for + // failure) + rmm::device_uvector d_elem_has_invalid_enum(total_count, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), + d_elem_has_invalid_enum.begin(), + d_elem_has_invalid_enum.end(), + false); + launch_validate_enum_values(enum_ints.data(), + elem_valid.data(), + d_elem_has_invalid_enum.data(), + lookup.d_valid_enums.data(), + static_cast(valid_enums.size()), + total_count, + stream); + + rmm::device_uvector d_top_row_indices(total_count, stream, mr); + thrust::transform(rmm::exec_policy_nosync(stream), + d_occurrences.begin(), + d_occurrences.end(), + d_top_row_indices.begin(), + [] __device__(repeated_occurrence const& occ) { return occ.row_idx; }); + propagate_invalid_enum_flags_to_rows( + d_elem_has_invalid_enum, d_row_force_null, total_count, d_top_row_indices.data(), true, stream); + + auto child_col = + build_enum_string_values_column(enum_ints, elem_valid, lookup, total_count, stream, mr); + + // Build the final LIST column from the per-row counts and decoded child strings. 
+ rmm::device_uvector lo(num_rows + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy_nosync(stream), d_field_counts.begin(), d_field_counts.end(), lo.begin(), 0); + int32_t tc_i32 = static_cast(total_count); + thrust::fill_n(rmm::exec_policy_nosync(stream), lo.data() + num_rows, 1, tc_i32); + + auto list_offs_col = std::make_unique( + cudf::data_type{cudf::type_id::INT32}, num_rows + 1, lo.release(), rmm::device_buffer{}, 0); + + auto input_null_count = binary_input.null_count(); + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(list_offs_col), + std::move(child_col), + input_null_count, + std::move(null_mask)); + } + return cudf::make_lists_column( + num_rows, std::move(list_offs_col), std::move(child_col), 0, rmm::device_buffer{}); +} + +std::unique_ptr build_repeated_string_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + bool is_bytes, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const input_null_count = binary_input.null_count(); + + if (total_count == 0) { + // All rows have count=0, but we still need to check input nulls + rmm::device_uvector offsets(num_rows + 1, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + auto child_col = is_bytes ? 
make_empty_column_safe( + cudf::data_type{cudf::type_id::LIST}, stream, mr) // LIST + : cudf::make_empty_column(cudf::data_type{cudf::type_id::STRING}); + + if (input_null_count > 0) { + // Copy input null mask - only input nulls produce output nulls + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask)); + } else { + // No input nulls, all rows get empty arrays [] + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); + } + } + + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy_nosync(stream), + d_field_counts.begin(), + d_field_counts.end(), + list_offs.begin(), + 0); + + int32_t total_count_i32 = static_cast(total_count); + thrust::fill_n(rmm::exec_policy_nosync(stream), list_offs.data() + num_rows, 1, total_count_i32); + + // Extract string lengths from occurrences + rmm::device_uvector str_lengths(total_count, stream, mr); + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((total_count + threads - 1u) / threads); + repeated_location_provider loc_provider{list_offsets, base_offset, d_occurrences.data()}; + extract_lengths_kernel + <<>>(loc_provider, total_count, str_lengths.data()); + + auto [str_offsets_col, total_chars] = cudf::strings::detail::make_offsets_child_column( + str_lengths.begin(), str_lengths.end(), stream, mr); + + rmm::device_uvector chars(total_chars, stream, mr); + if (total_chars > 0) { + repeated_location_provider copy_provider{list_offsets, base_offset, d_occurrences.data()}; + auto const* offsets_data = str_offsets_col->view().data(); + auto* chars_ptr = chars.data(); + + auto src_iter = cudf::detail::make_counting_transform_iterator( + 0, + cuda::proclaim_return_type( + [message_data, copy_provider] __device__(int idx) -> void const* { + int32_t data_offset 
= 0; + auto loc = copy_provider.get(idx, data_offset); + if (loc.offset < 0) return nullptr; + return static_cast(message_data + data_offset); + })); + auto dst_iter = cudf::detail::make_counting_transform_iterator( + 0, cuda::proclaim_return_type([chars_ptr, offsets_data] __device__(int idx) -> void* { + return static_cast(chars_ptr + offsets_data[idx]); + })); + auto size_iter = cudf::detail::make_counting_transform_iterator( + 0, cuda::proclaim_return_type([copy_provider] __device__(int idx) -> size_t { + int32_t data_offset = 0; + auto loc = copy_provider.get(idx, data_offset); + if (loc.offset < 0) return 0; + return static_cast(loc.length); + })); + + size_t temp_storage_bytes = 0; + cub::DeviceMemcpy::Batched( + nullptr, temp_storage_bytes, src_iter, dst_iter, size_iter, total_count, stream.value()); + rmm::device_buffer temp_storage(temp_storage_bytes, stream, mr); + cub::DeviceMemcpy::Batched(temp_storage.data(), + temp_storage_bytes, + src_iter, + dst_iter, + size_iter, + total_count, + stream.value()); + } + + std::unique_ptr child_col; + if (is_bytes) { + auto bytes_child = + std::make_unique(cudf::data_type{cudf::type_id::UINT8}, + total_chars, + rmm::device_buffer(chars.data(), total_chars, stream, mr), + rmm::device_buffer{}, + 0); + child_col = cudf::make_lists_column( + total_count, std::move(str_offsets_col), std::move(bytes_child), 0, rmm::device_buffer{}); + } else { + child_col = cudf::make_strings_column( + total_count, std::move(str_offsets_col), chars.release(), 0, rmm::device_buffer{}); + } + + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + + // Only rows where INPUT is null should produce null output + // Rows with valid input but count=0 should produce empty array [] + if (input_null_count > 0) { + // Copy input null mask - only input nulls produce output nulls + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return 
cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(child_col), + input_null_count, + std::move(null_mask)); + } + + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(child_col), 0, rmm::device_buffer{}); +} + +// Forward declaration -- build_nested_struct_column is defined after build_repeated_struct_column +// but the latter's STRUCT-child case needs to call it. +std::unique_ptr build_nested_struct_column( + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + rmm::device_uvector const& d_parent_locs, + std::vector const& child_field_indices, + std::vector const& schema, + int num_fields, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, + int depth, + bool propagate_invalid_rows); + +// Forward declaration -- build_repeated_child_list_column is defined after +// build_nested_struct_column but both build_repeated_struct_column and build_nested_struct_column +// need to call it. 
+std::unique_ptr build_repeated_child_list_column( + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_parent_rows, + int child_schema_idx, + std::vector const& schema, + int num_fields, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, + int depth, + bool propagate_invalid_rows); + +std::unique_ptr build_repeated_struct_column( + cudf::column_view const& binary_input, + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + device_nested_field_descriptor const& field_desc, + rmm::device_uvector const& d_field_counts, + rmm::device_uvector& d_occurrences, + int total_count, + int num_rows, + std::vector const& h_device_schema, + std::vector const& child_field_indices, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector const& schema, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error_top, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr) +{ + auto const input_null_count = binary_input.null_count(); + int num_child_fields = static_cast(child_field_indices.size()); + + if (total_count == 0 || num_child_fields == 0) { + // All rows have count=0 or no child fields - return list of empty structs + rmm::device_uvector offsets(num_rows + 1, stream, mr); + 
thrust::fill(rmm::exec_policy_nosync(stream), offsets.begin(), offsets.end(), 0); + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + offsets.release(), + rmm::device_buffer{}, + 0); + + // Build empty struct child column with proper nested structure + int num_schema_fields = static_cast(h_device_schema.size()); + std::vector> empty_struct_children; + for (int child_schema_idx : child_field_indices) { + auto child_type = cudf::data_type{schema[child_schema_idx].output_type}; + std::unique_ptr child_col; + if (child_type.id() == cudf::type_id::STRUCT) { + child_col = make_empty_struct_column_with_schema( + h_device_schema, child_schema_idx, num_schema_fields, stream, mr); + } else { + child_col = make_empty_column_safe(child_type, stream, mr); + } + if (h_device_schema[child_schema_idx].is_repeated) { + child_col = make_empty_list_column(std::move(child_col), stream, mr); + } + empty_struct_children.push_back(std::move(child_col)); + } + auto empty_struct = cudf::make_structs_column( + 0, std::move(empty_struct_children), 0, rmm::device_buffer{}, stream, mr); + + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(empty_struct), + input_null_count, + std::move(null_mask)); + } else { + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(empty_struct), 0, rmm::device_buffer{}); + } + } + + rmm::device_uvector list_offs(num_rows + 1, stream, mr); + thrust::exclusive_scan(rmm::exec_policy_nosync(stream), + d_field_counts.begin(), + d_field_counts.end(), + list_offs.begin(), + 0); + + int32_t total_count_i32 = static_cast(total_count); + thrust::fill_n(rmm::exec_policy_nosync(stream), list_offs.data() + num_rows, 1, total_count_i32); + + // Build child field descriptors for scanning within each message occurrence + std::vector h_child_descs(num_child_fields); + for (int ci = 0; ci < 
num_child_fields; ci++) { + int child_schema_idx = child_field_indices[ci]; + h_child_descs[ci].field_number = h_device_schema[child_schema_idx].field_number; + h_child_descs[ci].expected_wire_type = h_device_schema[child_schema_idx].wire_type; + h_child_descs[ci].is_repeated = h_device_schema[child_schema_idx].is_repeated; + } + rmm::device_uvector d_child_descs(num_child_fields, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_descs.data(), + h_child_descs.data(), + num_child_fields * sizeof(field_descriptor), + cudaMemcpyHostToDevice, + stream.value())); + auto h_child_lookup = build_field_lookup_table(h_child_descs.data(), num_child_fields); + rmm::device_uvector d_child_lookup(0, stream, mr); + if (!h_child_lookup.empty()) { + d_child_lookup = rmm::device_uvector(h_child_lookup.size(), stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_lookup.data(), + h_child_lookup.data(), + h_child_lookup.size() * sizeof(int), + cudaMemcpyHostToDevice, + stream.value())); + } + + // For each occurrence, we need to scan for child fields + // Create "virtual" parent locations from the occurrences using GPU kernel + // This replaces the host-side loop with D->H->D copy pattern (critical performance fix!) 
+ rmm::device_uvector d_msg_locs(total_count, stream, mr); + rmm::device_uvector d_msg_row_offsets(total_count, stream, mr); + launch_compute_msg_locations_from_occurrences(d_occurrences.data(), + list_offsets, + base_offset, + d_msg_locs.data(), + d_msg_row_offsets.data(), + total_count, + d_error_top.data(), + stream); + rmm::device_uvector d_top_row_indices(total_count, stream, mr); + thrust::transform(rmm::exec_policy_nosync(stream), + d_occurrences.data(), + d_occurrences.end(), + d_top_row_indices.data(), + [] __device__(repeated_occurrence const& occ) { return occ.row_idx; }); + + // Scan for child fields within each message occurrence + rmm::device_uvector d_child_locs(total_count * num_child_fields, stream, mr); + // Reuse top-level error flag so failfast can observe nested repeated-message failures. + auto& d_error = d_error_top; + + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((total_count + threads - 1u) / threads); + + // Use a custom kernel to scan child fields within message occurrences + // This is similar to scan_nested_message_fields_kernel but operates on occurrences + launch_scan_repeated_message_children(message_data, + message_data_size, + d_msg_row_offsets.data(), + d_msg_locs.data(), + total_count, + d_child_descs.data(), + num_child_fields, + d_child_locs.data(), + d_error.data(), + h_child_lookup.empty() ? nullptr : d_child_lookup.data(), + static_cast(d_child_lookup.size()), + stream); + + // Enforce proto2 required semantics for fields inside each repeated message occurrence. + maybe_check_required_fields(d_child_locs.data(), + child_field_indices, + schema, + total_count, + nullptr, + 0, + nullptr, + d_row_force_null.size() > 0 ? d_row_force_null.data() : nullptr, + d_top_row_indices.data(), + d_error.data(), + stream); + + // Note: We no longer need to copy child_locs to host because: + // 1. All scalar extraction kernels access d_child_locs directly on device + // 2. 
String extraction uses GPU kernels + // 3. Nested struct locations are computed on GPU via compute_nested_struct_locations_kernel + + // Extract child field values - build one column per child field + std::vector> struct_children; + int num_schema_fields = static_cast(h_device_schema.size()); + for (int ci = 0; ci < num_child_fields; ci++) { + int child_schema_idx = child_field_indices[ci]; + auto const dt = cudf::data_type{schema[child_schema_idx].output_type}; + auto const enc = h_device_schema[child_schema_idx].encoding; + bool has_def = h_device_schema[child_schema_idx].has_default_value; + bool child_is_repeated = h_device_schema[child_schema_idx].is_repeated; + + if (child_is_repeated) { + struct_children.push_back(build_repeated_child_list_column(message_data, + message_data_size, + d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + total_count, + child_schema_idx, + schema, + num_schema_fields, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_force_null, + d_error_top, + stream, + mr, + d_top_row_indices.data(), + 1, + false)); + continue; + } + + switch (dt.id()) { + case cudf::type_id::BOOL8: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + case cudf::type_id::FLOAT32: + case cudf::type_id::FLOAT64: { + repeated_msg_child_location_provider loc_provider{d_msg_row_offsets.data(), + 0, + d_msg_locs.data(), + d_child_locs.data(), + ci, + num_child_fields}; + struct_children.push_back( + extract_typed_column(dt, + enc, + message_data, + loc_provider, + total_count, + blocks, + threads, + has_def, + has_def ? default_ints[child_schema_idx] : 0, + has_def ? default_floats[child_schema_idx] : 0.0, + has_def ? 
default_bools[child_schema_idx] : false, + default_strings[child_schema_idx], + child_schema_idx, + enum_valid_values, + enum_names, + d_row_force_null, + d_error, + stream, + mr, + d_top_row_indices.data(), + false)); + break; + } + case cudf::type_id::STRING: { + if (enc == encoding_value(proto_encoding::ENUM_STRING)) { + if (child_schema_idx < static_cast(enum_valid_values.size()) && + child_schema_idx < static_cast(enum_names.size()) && + !enum_valid_values[child_schema_idx].empty() && + enum_valid_values[child_schema_idx].size() == enum_names[child_schema_idx].size()) { + struct_children.push_back( + build_repeated_msg_child_enum_string_column(message_data, + d_msg_row_offsets, + d_msg_locs, + d_child_locs, + ci, + num_child_fields, + total_count, + enum_valid_values[child_schema_idx], + enum_names[child_schema_idx], + d_row_force_null, + d_top_row_indices.data(), + false, + d_error, + stream, + mr)); + } else { + set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); + struct_children.push_back(make_null_column(dt, total_count, stream, mr)); + } + } else { + struct_children.push_back(build_repeated_msg_child_varlen_column(message_data, + d_msg_row_offsets, + d_msg_locs, + d_child_locs, + ci, + num_child_fields, + total_count, + d_error, + false, + stream, + mr)); + } + break; + } + case cudf::type_id::LIST: { + struct_children.push_back(build_repeated_msg_child_varlen_column(message_data, + d_msg_row_offsets, + d_msg_locs, + d_child_locs, + ci, + num_child_fields, + total_count, + d_error, + true, + stream, + mr)); + break; + } + case cudf::type_id::STRUCT: { + // Nested struct inside repeated message - use recursive build_nested_struct_column + int num_schema_fields = static_cast(h_device_schema.size()); + auto grandchild_indices = + find_child_field_indices(h_device_schema, num_schema_fields, child_schema_idx); + + if (grandchild_indices.empty()) { + struct_children.push_back( + cudf::make_structs_column(total_count, + std::vector>{}, + 0, + 
rmm::device_buffer{}, + stream, + mr)); + } else { + // Compute virtual parent locations for each occurrence's nested struct child + rmm::device_uvector d_nested_locs(total_count, stream, mr); + rmm::device_uvector d_nested_row_offsets(total_count, stream, mr); + launch_compute_nested_struct_locations(d_child_locs.data(), + d_msg_locs.data(), + d_msg_row_offsets.data(), + ci, + num_child_fields, + d_nested_locs.data(), + d_nested_row_offsets.data(), + total_count, + d_error_top.data(), + stream); + + struct_children.push_back(build_nested_struct_column(message_data, + message_data_size, + d_nested_row_offsets.data(), + 0, + d_nested_locs, + grandchild_indices, + schema, + num_schema_fields, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_force_null, + d_error_top, + total_count, + stream, + mr, + d_top_row_indices.data(), + 0, + false)); + } + break; + } + default: + // Unsupported child type - create null column + struct_children.push_back(make_null_column(dt, total_count, stream, mr)); + break; + } + } + + // Build the struct column from child columns + auto struct_col = cudf::make_structs_column( + total_count, std::move(struct_children), 0, rmm::device_buffer{}, stream, mr); + + // Build the list offsets column + auto offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + + // Build the final LIST column + if (input_null_count > 0) { + auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); + return cudf::make_lists_column(num_rows, + std::move(offsets_col), + std::move(struct_col), + input_null_count, + std::move(null_mask)); + } + + return cudf::make_lists_column( + num_rows, std::move(offsets_col), std::move(struct_col), 0, rmm::device_buffer{}); +} + +std::unique_ptr build_nested_struct_column( + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* list_offsets, + 
cudf::size_type base_offset, + rmm::device_uvector const& d_parent_locs, + std::vector const& child_field_indices, + std::vector const& schema, + int num_fields, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + int num_rows, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, + int depth, + bool propagate_invalid_rows) +{ + CUDF_EXPECTS(depth < MAX_NESTING_DEPTH, + "Nested protobuf struct depth exceeds supported decode recursion limit"); + + if (num_rows == 0) { + std::vector> empty_children; + for (int child_schema_idx : child_field_indices) { + auto child_type = cudf::data_type{schema[child_schema_idx].output_type}; + std::unique_ptr child_col; + if (child_type.id() == cudf::type_id::STRUCT) { + child_col = + make_empty_struct_column_with_schema(schema, child_schema_idx, num_fields, stream, mr); + } else { + child_col = make_empty_column_safe(child_type, stream, mr); + } + if (schema[child_schema_idx].is_repeated) { + child_col = make_empty_list_column(std::move(child_col), stream, mr); + } + empty_children.push_back(std::move(child_col)); + } + return cudf::make_structs_column( + 0, std::move(empty_children), 0, rmm::device_buffer{}, stream, mr); + } + + auto const threads = THREADS_PER_BLOCK; + auto const blocks = static_cast((num_rows + threads - 1u) / threads); + int num_child_fields = static_cast(child_field_indices.size()); + + std::vector h_child_field_descs(num_child_fields); + for (int i = 0; i < num_child_fields; i++) { + int child_idx = child_field_indices[i]; + h_child_field_descs[i].field_number = schema[child_idx].field_number; + h_child_field_descs[i].expected_wire_type = static_cast(schema[child_idx].wire_type); + h_child_field_descs[i].is_repeated = 
schema[child_idx].is_repeated; + } + + rmm::device_uvector d_child_field_descs(num_child_fields, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_child_field_descs.data(), + h_child_field_descs.data(), + num_child_fields * sizeof(field_descriptor), + cudaMemcpyHostToDevice, + stream.value())); + + rmm::device_uvector d_child_locations( + static_cast(num_rows) * num_child_fields, stream, mr); + launch_scan_nested_message_fields(message_data, + message_data_size, + list_offsets, + base_offset, + d_parent_locs.data(), + num_rows, + d_child_field_descs.data(), + num_child_fields, + d_child_locations.data(), + d_error.data(), + stream); + + // Enforce proto2 required semantics for direct children of this nested message. + maybe_check_required_fields(d_child_locations.data(), + child_field_indices, + schema, + num_rows, + nullptr, + 0, + d_parent_locs.data(), + d_row_force_null.size() > 0 ? d_row_force_null.data() : nullptr, + top_row_indices, + d_error.data(), + stream); + + std::vector> struct_children; + for (int ci = 0; ci < num_child_fields; ci++) { + int child_schema_idx = child_field_indices[ci]; + auto const dt = cudf::data_type{schema[child_schema_idx].output_type}; + auto const enc = static_cast(schema[child_schema_idx].encoding); + bool has_def = schema[child_schema_idx].has_default_value; + bool is_repeated = schema[child_schema_idx].is_repeated; + + if (is_repeated) { + struct_children.push_back(build_repeated_child_list_column(message_data, + message_data_size, + list_offsets, + base_offset, + d_parent_locs.data(), + num_rows, + child_schema_idx, + schema, + num_fields, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_force_null, + d_error, + stream, + mr, + top_row_indices, + depth, + propagate_invalid_rows)); + continue; + } + + switch (dt.id()) { + case cudf::type_id::BOOL8: + case cudf::type_id::INT32: + case cudf::type_id::UINT32: + case cudf::type_id::INT64: + case cudf::type_id::UINT64: + 
case cudf::type_id::FLOAT32: + case cudf::type_id::FLOAT64: { + nested_location_provider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + struct_children.push_back( + extract_typed_column(dt, + enc, + message_data, + loc_provider, + num_rows, + blocks, + threads, + has_def, + has_def ? default_ints[child_schema_idx] : 0, + has_def ? default_floats[child_schema_idx] : 0.0, + has_def ? default_bools[child_schema_idx] : false, + default_strings[child_schema_idx], + child_schema_idx, + enum_valid_values, + enum_names, + d_row_force_null, + d_error, + stream, + mr, + top_row_indices, + propagate_invalid_rows)); + break; + } + case cudf::type_id::STRING: { + if (enc == encoding_value(proto_encoding::ENUM_STRING)) { + rmm::device_uvector out(num_rows, stream, mr); + rmm::device_uvector valid((num_rows > 0 ? num_rows : 1), stream, mr); + int64_t def_int = has_def ? default_ints[child_schema_idx] : 0; + nested_location_provider loc_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + extract_varint_kernel + <<>>(message_data, + loc_provider, + num_rows, + out.data(), + valid.data(), + d_error.data(), + has_def, + def_int); + + if (child_schema_idx < static_cast(enum_valid_values.size()) && + child_schema_idx < static_cast(enum_names.size())) { + auto const& valid_enums = enum_valid_values[child_schema_idx]; + auto const& enum_name_bytes = enum_names[child_schema_idx]; + if (!valid_enums.empty() && valid_enums.size() == enum_name_bytes.size()) { + struct_children.push_back(build_enum_string_column(out, + valid, + valid_enums, + enum_name_bytes, + d_row_force_null, + num_rows, + stream, + mr, + top_row_indices, + propagate_invalid_rows)); + } else { + set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); + struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); + } + } else { + 
set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); + struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); + } + } else { + bool has_def_str = has_def; + auto const& def_str = default_strings[child_schema_idx]; + nested_location_provider len_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + nested_location_provider copy_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + auto valid_fn = [plocs = d_parent_locs.data(), + flocs = d_child_locations.data(), + ci, + num_child_fields, + has_def_str] __device__(cudf::size_type row) { + return (plocs[row].offset >= 0 && + flocs[flat_index(static_cast(row), + static_cast(num_child_fields), + static_cast(ci))] + .offset >= 0) || + has_def_str; + }; + struct_children.push_back(extract_and_build_string_or_bytes_column(false, + message_data, + num_rows, + len_provider, + copy_provider, + valid_fn, + has_def_str, + def_str, + d_error, + stream, + mr)); + } + break; + } + case cudf::type_id::LIST: { + // bytes (BinaryType) represented as LIST + bool has_def_bytes = has_def; + auto const& def_bytes = default_strings[child_schema_idx]; + nested_location_provider len_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + nested_location_provider copy_provider{list_offsets, + base_offset, + d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields}; + auto valid_fn = [plocs = d_parent_locs.data(), + flocs = d_child_locations.data(), + ci, + num_child_fields, + has_def_bytes] __device__(cudf::size_type row) { + return (plocs[row].offset >= 0 && flocs[flat_index(static_cast(row), + static_cast(num_child_fields), + static_cast(ci))] + .offset >= 0) || + has_def_bytes; + }; + struct_children.push_back(extract_and_build_string_or_bytes_column(true, + message_data, + num_rows, + len_provider, 
+ copy_provider, + valid_fn, + has_def_bytes, + def_bytes, + d_error, + stream, + mr)); + break; + } + case cudf::type_id::STRUCT: { + auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx); + if (gc_indices.empty()) { + struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); + break; + } + rmm::device_uvector d_gc_parent(num_rows, stream, mr); + launch_compute_grandchild_parent_locations(d_parent_locs.data(), + d_child_locations.data(), + ci, + num_child_fields, + d_gc_parent.data(), + num_rows, + d_error.data(), + stream); + struct_children.push_back(build_nested_struct_column(message_data, + message_data_size, + list_offsets, + base_offset, + d_gc_parent, + gc_indices, + schema, + num_fields, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_force_null, + d_error, + num_rows, + stream, + mr, + top_row_indices, + depth + 1, + propagate_invalid_rows)); + break; + } + default: struct_children.push_back(make_null_column(dt, num_rows, stream, mr)); break; + } + } + + rmm::device_uvector struct_valid((num_rows > 0 ? num_rows : 1), stream, mr); + thrust::transform( + rmm::exec_policy_nosync(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + struct_valid.data(), + [plocs = d_parent_locs.data()] __device__(auto row) { return plocs[row].offset >= 0; }); + auto [struct_mask, struct_null_count] = make_null_mask_from_valid(struct_valid, stream, mr); + return cudf::make_structs_column( + num_rows, std::move(struct_children), struct_null_count, std::move(struct_mask), stream, mr); +} + +/** + * Build a LIST column for a repeated child field inside a parent message. + * Shared between build_nested_struct_column and build_repeated_struct_column. 
+ */ +std::unique_ptr build_repeated_child_list_column( + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_parent_rows, + int child_schema_idx, + std::vector const& schema, + int num_fields, + std::vector const& default_ints, + std::vector const& default_floats, + std::vector const& default_bools, + std::vector> const& default_strings, + std::vector> const& enum_valid_values, + std::vector>> const& enum_names, + rmm::device_uvector& d_row_force_null, + rmm::device_uvector& d_error, + rmm::cuda_stream_view stream, + rmm::device_async_resource_ref mr, + int32_t const* top_row_indices, + int depth, + bool propagate_invalid_rows) +{ + auto elem_type_id = schema[child_schema_idx].output_type; + rmm::device_uvector d_rep_info(num_parent_rows, stream, mr); + + std::vector rep_indices = {0}; + rmm::device_uvector d_rep_indices(1, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync( + d_rep_indices.data(), rep_indices.data(), sizeof(int), cudaMemcpyHostToDevice, stream.value())); + + device_nested_field_descriptor rep_desc; + rep_desc.field_number = schema[child_schema_idx].field_number; + rep_desc.wire_type = static_cast(schema[child_schema_idx].wire_type); + rep_desc.output_type_id = static_cast(schema[child_schema_idx].output_type); + rep_desc.is_repeated = true; + rep_desc.parent_idx = -1; + rep_desc.depth = 0; + rep_desc.encoding = 0; + rep_desc.is_required = false; + rep_desc.has_default_value = false; + CUDF_EXPECTS(schema[child_schema_idx].is_repeated, + "count_repeated_in_nested_kernel launch requires repeated child schema"); + CUDF_EXPECTS(rep_desc.depth == 0, + "count_repeated_in_nested_kernel launch requires pre-filtered local depth 0"); + + std::vector h_rep_schema = {rep_desc}; + rmm::device_uvector d_rep_schema(1, stream, mr); + CUDF_CUDA_TRY(cudaMemcpyAsync(d_rep_schema.data(), + h_rep_schema.data(), + 
sizeof(device_nested_field_descriptor), + cudaMemcpyHostToDevice, + stream.value())); + + launch_count_repeated_in_nested(message_data, + message_data_size, + row_offsets, + base_offset, + parent_locs, + num_parent_rows, + d_rep_schema.data(), + 1, + d_rep_info.data(), + 1, + d_rep_indices.data(), + d_error.data(), + stream); + + rmm::device_uvector d_rep_counts(num_parent_rows, stream, mr); + thrust::transform(rmm::exec_policy_nosync(stream), + d_rep_info.data(), + d_rep_info.end(), + d_rep_counts.data(), + [] __device__(repeated_field_info const& info) { return info.count; }); + int total_rep_count = + thrust::reduce(rmm::exec_policy_nosync(stream), d_rep_counts.data(), d_rep_counts.end(), 0); + + if (total_rep_count == 0) { + rmm::device_uvector list_offsets_vec(num_parent_rows + 1, stream, mr); + thrust::fill( + rmm::exec_policy_nosync(stream), list_offsets_vec.data(), list_offsets_vec.end(), 0); + auto list_offsets_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_parent_rows + 1, + list_offsets_vec.release(), + rmm::device_buffer{}, + 0); + std::unique_ptr child_col; + if (elem_type_id == cudf::type_id::STRUCT) { + child_col = + make_empty_struct_column_with_schema(schema, child_schema_idx, num_fields, stream, mr); + } else { + child_col = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); + } + return cudf::make_lists_column( + num_parent_rows, std::move(list_offsets_col), std::move(child_col), 0, rmm::device_buffer{}); + } + + rmm::device_uvector list_offs(num_parent_rows + 1, stream, mr); + thrust::exclusive_scan( + rmm::exec_policy_nosync(stream), d_rep_counts.data(), d_rep_counts.end(), list_offs.begin(), 0); + thrust::fill_n( + rmm::exec_policy_nosync(stream), list_offs.data() + num_parent_rows, 1, total_rep_count); + + rmm::device_uvector d_rep_occs(total_rep_count, stream, mr); + launch_scan_repeated_in_nested(message_data, + message_data_size, + row_offsets, + base_offset, + parent_locs, + num_parent_rows, + 
d_rep_schema.data(), + list_offs.data(), + d_rep_indices.data(), + d_rep_occs.data(), + d_error.data(), + stream); + + rmm::device_uvector d_rep_top_row_indices(total_rep_count, stream, mr); + thrust::transform(rmm::exec_policy_nosync(stream), + d_rep_occs.begin(), + d_rep_occs.end(), + d_rep_top_row_indices.begin(), + [top_row_indices] __device__(repeated_occurrence const& occ) { + return top_row_indices != nullptr ? top_row_indices[occ.row_idx] + : occ.row_idx; + }); + + std::unique_ptr child_values; + auto const rep_blocks = + static_cast((total_rep_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + nested_repeated_location_provider nr_loc{ + row_offsets, base_offset, parent_locs, d_rep_occs.data()}; + + if (elem_type_id == cudf::type_id::BOOL8 || elem_type_id == cudf::type_id::INT32 || + elem_type_id == cudf::type_id::UINT32 || elem_type_id == cudf::type_id::INT64 || + elem_type_id == cudf::type_id::UINT64 || elem_type_id == cudf::type_id::FLOAT32 || + elem_type_id == cudf::type_id::FLOAT64) { + child_values = extract_typed_column(cudf::data_type{elem_type_id}, + static_cast(schema[child_schema_idx].encoding), + message_data, + nr_loc, + total_rep_count, + rep_blocks, + THREADS_PER_BLOCK, + false, + 0, + 0.0, + false, + cudf::detail::make_pinned_vector_async(0, stream), + child_schema_idx, + enum_valid_values, + enum_names, + d_row_force_null, + d_error, + stream, + mr, + d_rep_top_row_indices.data(), + propagate_invalid_rows); + } else if (elem_type_id == cudf::type_id::STRING || elem_type_id == cudf::type_id::LIST) { + if (elem_type_id == cudf::type_id::STRING && + schema[child_schema_idx].encoding == proto_encoding::ENUM_STRING) { + if (child_schema_idx < static_cast(enum_valid_values.size()) && + child_schema_idx < static_cast(enum_names.size()) && + !enum_valid_values[child_schema_idx].empty() && + enum_valid_values[child_schema_idx].size() == enum_names[child_schema_idx].size()) { + auto lookup = make_enum_string_lookup_tables( + 
enum_valid_values[child_schema_idx], enum_names[child_schema_idx], stream, mr); + rmm::device_uvector enum_values(total_rep_count, stream, mr); + rmm::device_uvector valid((total_rep_count > 0 ? total_rep_count : 1), stream, mr); + extract_varint_kernel + <<>>(message_data, + nr_loc, + total_rep_count, + enum_values.data(), + valid.data(), + d_error.data(), + false, + 0); + + rmm::device_uvector d_elem_has_invalid_enum(total_rep_count, stream, mr); + thrust::fill(rmm::exec_policy_nosync(stream), + d_elem_has_invalid_enum.begin(), + d_elem_has_invalid_enum.end(), + false); + launch_validate_enum_values(enum_values.data(), + valid.data(), + d_elem_has_invalid_enum.data(), + lookup.d_valid_enums.data(), + static_cast(lookup.d_valid_enums.size()), + total_rep_count, + stream); + propagate_invalid_enum_flags_to_rows(d_elem_has_invalid_enum, + d_row_force_null, + total_rep_count, + d_rep_top_row_indices.data(), + propagate_invalid_rows, + stream); + child_values = + build_enum_string_values_column(enum_values, valid, lookup, total_rep_count, stream, mr); + } else { + set_error_once_async(d_error.data(), ERR_MISSING_ENUM_META, stream); + child_values = make_null_column(cudf::data_type{elem_type_id}, total_rep_count, stream, mr); + } + } else { + bool as_bytes = (elem_type_id == cudf::type_id::LIST); + auto valid_fn = [] __device__(cudf::size_type) { return true; }; + auto empty_default = cudf::detail::make_pinned_vector_async(0, stream); + child_values = extract_and_build_string_or_bytes_column(as_bytes, + message_data, + total_rep_count, + nr_loc, + nr_loc, + valid_fn, + false, + empty_default, + d_error, + stream, + mr); + } + } else if (elem_type_id == cudf::type_id::STRUCT) { + auto gc_indices = find_child_field_indices(schema, num_fields, child_schema_idx); + if (gc_indices.empty()) { + child_values = cudf::make_structs_column(total_rep_count, + std::vector>{}, + 0, + rmm::device_buffer{}, + stream, + mr); + } else { + rmm::device_uvector 
d_virtual_row_offsets(total_rep_count, stream, mr); + rmm::device_uvector d_virtual_parent_locs(total_rep_count, stream, mr); + launch_compute_virtual_parents_for_nested_repeated(d_rep_occs.data(), + row_offsets, + parent_locs, + d_virtual_row_offsets.data(), + d_virtual_parent_locs.data(), + total_rep_count, + d_error.data(), + stream); + + child_values = build_nested_struct_column(message_data, + message_data_size, + d_virtual_row_offsets.data(), + base_offset, + d_virtual_parent_locs, + gc_indices, + schema, + num_fields, + default_ints, + default_floats, + default_bools, + default_strings, + enum_valid_values, + enum_names, + d_row_force_null, + d_error, + total_rep_count, + stream, + mr, + d_rep_top_row_indices.data(), + depth + 1, + propagate_invalid_rows); + } + } else { + child_values = make_empty_column_safe(cudf::data_type{elem_type_id}, stream, mr); + } + + auto list_offs_col = std::make_unique(cudf::data_type{cudf::type_id::INT32}, + num_parent_rows + 1, + list_offs.release(), + rmm::device_buffer{}, + 0); + return cudf::make_lists_column( + num_parent_rows, std::move(list_offs_col), std::move(child_values), 0, rmm::device_buffer{}); +} + } // namespace spark_rapids_jni::protobuf::detail diff --git a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp index d1179dc858..c228329db2 100644 --- a/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp +++ b/src/main/cpp/src/protobuf/protobuf_host_helpers.hpp @@ -27,6 +27,7 @@ #include #include #include +#include #include namespace spark_rapids_jni::protobuf::detail { @@ -98,6 +99,20 @@ std::unique_ptr make_empty_list_column(std::unique_ptr +inline cudf::type_id get_output_type_id(FieldT const& field) +{ + if constexpr (std::is_same_v) { + return static_cast(field.output_type_id); + } else { + return field.output_type; + } +} + template std::unique_ptr make_empty_struct_column_with_schema( SchemaT const& schema, @@ -110,7 +125,7 @@ std::unique_ptr 
make_empty_struct_column_with_schema( std::vector> children; for (int child_idx : child_indices) { - auto child_type = cudf::data_type{schema[child_idx].output_type}; + auto child_type = cudf::data_type{get_output_type_id(schema[child_idx])}; std::unique_ptr child_col; if (child_type.id() == cudf::type_id::STRUCT) { diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cu b/src/main/cpp/src/protobuf/protobuf_kernels.cu index ddd09e881f..2b771302b3 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cu +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cu @@ -17,6 +17,7 @@ #include "protobuf/protobuf_kernels.cuh" #include +#include #include #include @@ -31,30 +32,1302 @@ namespace spark_rapids_jni::protobuf::detail { namespace { +// ============================================================================ +// Pass 1: Scan all fields kernel - records (offset, length) for each field +// ============================================================================ + CUDF_KERNEL void set_error_if_unset_kernel(int* error_flag, int error_code) { if (blockIdx.x == 0 && threadIdx.x == 0) { set_error_once(error_flag, error_code); } } -// Stub kernels — replaced with real implementations in follow-up PRs. -CUDF_KERNEL void check_required_fields_kernel(field_location const*, - uint8_t const*, - int, - int, - cudf::bitmask_type const*, - cudf::size_type, - field_location const*, - bool*, - int32_t const*, - int*) +/** + * Fused scanning kernel: scans each message once and records the location + * of all requested fields. + * + * For "last one wins" semantics (protobuf standard for repeated scalars), + * we continue scanning even after finding a field. + * + * If a row hits a parse error that leaves the cursor in an unsafe state (for example, malformed + * varint bytes or a schema-matching field with the wrong wire type), the scan aborts for that row + * instead of guessing where the next field begins. 
In permissive mode the caller may also supply a + * row-level invalidity buffer so the full struct row can be nulled to match Spark CPU semantics for + * malformed messages. + */ +CUDF_KERNEL void scan_all_fields_kernel( + cudf::column_device_view const d_in, + field_descriptor const* field_descs, // [num_fields] + int num_fields, + int const* field_lookup, // direct-mapped lookup table (nullable) + int field_lookup_size, // size of lookup table (0 if null) + field_location* locations, // [num_rows * num_fields] row-major + int* error_flag, + bool* row_has_invalid_data) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + cudf::detail::lists_column_device_view in{d_in}; + if (row >= in.size()) { return; } + + auto mark_row_error = [&]() { + if (row_has_invalid_data != nullptr) { row_has_invalid_data[row] = true; } + }; + + for (int f = 0; f < num_fields; f++) { + locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(f))] = {-1, 0}; + } + + if (in.nullable() && in.is_null(row)) { return; } + + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); + auto const* bytes = reinterpret_cast(child.data()); + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; + + if (!check_message_bounds(start, end, child.size(), error_flag)) { + mark_row_error(); + return; + } + + uint8_t const* cur = bytes + start; + uint8_t const* msg_end = bytes + end; + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) { + mark_row_error(); + return; + } + int fn = tag.field_number; + int wt = tag.wire_type; + + int f = lookup_field(fn, field_lookup, field_lookup_size, field_descs, num_fields); + if (f >= 0) { + if (wt != field_descs[f].expected_wire_type) { + set_error_once(error_flag, ERR_WIRE_TYPE); + mark_row_error(); + return; + } + + // Record the location (relative to message start) + int data_offset = static_cast(cur - bytes - start); + + if (wt == 
wire_type_value(proto_wire_type::LEN)) { + // For length-delimited, record offset after length prefix and the data length + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + mark_row_error(); + return; + } + if (len > static_cast(msg_end - cur - len_bytes) || + len > static_cast(cuda::std::numeric_limits::max())) { + set_error_once(error_flag, ERR_OVERFLOW); + mark_row_error(); + return; + } + // Record offset pointing to the actual data (after length prefix) + int32_t data_location; + if (!checked_add_int32(data_offset, len_bytes, data_location)) { + set_error_once(error_flag, ERR_OVERFLOW); + mark_row_error(); + return; + } + locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(f))] = { + data_location, static_cast(len)}; + } else { + // For fixed-size and varint fields, record offset and compute length + int field_size = get_wire_type_size(wt, cur, msg_end); + if (field_size < 0) { + set_error_once(error_flag, ERR_FIELD_SIZE); + mark_row_error(); + return; + } + locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(f))] = { + data_offset, field_size}; + } + } + + // Skip to next field + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + mark_row_error(); + return; + } + cur = next; + } +} + +// ============================================================================ +// Shared device functions for repeated field processing +// ============================================================================ + +/** + * Count a single repeated field occurrence (packed or unpacked). + * Updates info.count and info.total_length. + * Returns false on error (error_flag set), true on success. 
+ */ +__device__ bool count_repeated_element(uint8_t const* cur, + uint8_t const* msg_end, + int wt, + int expected_wt, + repeated_field_info& info, + int* error_flag) +{ + bool is_packed = (wt == wire_type_value(proto_wire_type::LEN) && + expected_wt != wire_type_value(proto_wire_type::LEN)); + + if (!is_packed && wt != expected_wt) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return false; + } + + if (is_packed) { + uint64_t packed_len; + int len_bytes; + if (!read_varint(cur, msg_end, packed_len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + uint8_t const* packed_start = cur + len_bytes; + if (packed_len > static_cast(msg_end - packed_start)) { + set_error_once(error_flag, ERR_OVERFLOW); + return false; + } + uint8_t const* packed_end = packed_start + packed_len; + + int count = 0; + if (expected_wt == wire_type_value(proto_wire_type::VARINT)) { + uint8_t const* p = packed_start; + while (p < packed_end) { + uint64_t dummy; + int vbytes; + if (!read_varint(p, packed_end, dummy, vbytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + p += vbytes; + count++; + } + } else if (expected_wt == wire_type_value(proto_wire_type::I32BIT)) { + if ((packed_len % 4) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return false; + } + count = static_cast(packed_len / 4); + } else if (expected_wt == wire_type_value(proto_wire_type::I64BIT)) { + if ((packed_len % 8) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return false; + } + count = static_cast(packed_len / 8); + } + + info.count += count; + info.total_length += static_cast(packed_len); + } else { + int32_t data_offset, data_length; + if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { + set_error_once(error_flag, ERR_FIELD_SIZE); + return false; + } + info.count++; + info.total_length += data_length; + } + return true; +} + +/** + * Record a single repeated field occurrence (packed or unpacked). 
+ * Writes to occurrences[write_idx] and advances write_idx. + * Offsets are computed relative to msg_base. + * Returns false on error (error_flag set), true on success. + */ +__device__ bool scan_repeated_element(uint8_t const* cur, + uint8_t const* msg_end, + uint8_t const* msg_base, + int wt, + int expected_wt, + int32_t row, + repeated_occurrence* occurrences, + int& write_idx, + int write_end, + int* error_flag) +{ + bool is_packed = (wt == wire_type_value(proto_wire_type::LEN) && + expected_wt != wire_type_value(proto_wire_type::LEN)); + + if (!is_packed && wt != expected_wt) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return false; + } + + if (is_packed) { + uint64_t packed_len; + int len_bytes; + if (!read_varint(cur, msg_end, packed_len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + uint8_t const* packed_start = cur + len_bytes; + if (packed_len > static_cast(msg_end - packed_start)) { + set_error_once(error_flag, ERR_OVERFLOW); + return false; + } + uint8_t const* packed_end = packed_start + packed_len; + + if (expected_wt == wire_type_value(proto_wire_type::VARINT)) { + uint8_t const* p = packed_start; + while (p < packed_end) { + int32_t elem_offset = static_cast(p - msg_base); + uint64_t dummy; + int vbytes; + if (!read_varint(p, packed_end, dummy, vbytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + if (write_idx >= write_end) { + set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); + return false; + } + occurrences[write_idx] = {row, elem_offset, vbytes}; + write_idx++; + p += vbytes; + } + } else if (expected_wt == wire_type_value(proto_wire_type::I32BIT)) { + if ((packed_len % 4) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return false; + } + for (uint64_t i = 0; i < packed_len; i += 4) { + if (write_idx >= write_end) { + set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); + return false; + } + occurrences[write_idx] = {row, static_cast(packed_start - msg_base + i), 4}; 
+ write_idx++; + } + } else if (expected_wt == wire_type_value(proto_wire_type::I64BIT)) { + if ((packed_len % 8) != 0) { + set_error_once(error_flag, ERR_FIXED_LEN); + return false; + } + for (uint64_t i = 0; i < packed_len; i += 8) { + if (write_idx >= write_end) { + set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); + return false; + } + occurrences[write_idx] = {row, static_cast(packed_start - msg_base + i), 8}; + write_idx++; + } + } + } else { + int32_t data_offset, data_length; + if (!get_field_data_location(cur, msg_end, wt, data_offset, data_length)) { + set_error_once(error_flag, ERR_FIELD_SIZE); + return false; + } + if (write_idx >= write_end) { + set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); + return false; + } + int32_t abs_offset = static_cast(cur - msg_base) + data_offset; + occurrences[write_idx] = {row, abs_offset, data_length}; + write_idx++; + } + return true; +} + +// ============================================================================ +// Pass 1b: Count repeated fields kernel +// ============================================================================ + +/** + * Count occurrences of repeated fields in each row. + * Also records locations of nested message fields for hierarchical processing. + * + * Optional lookup tables (fn_to_rep_idx, fn_to_nested_idx) provide O(1) field_number + * to local index mapping. When nullptr, falls back to linear search. 
+ */ +CUDF_KERNEL void count_repeated_fields_kernel(cudf::column_device_view const d_in, + device_nested_field_descriptor const* schema, + int num_fields, + int depth_level, + repeated_field_info* repeated_info, + int num_repeated_fields, + int const* repeated_field_indices, + field_location* nested_locations, + int num_nested_fields, + int const* nested_field_indices, + int* error_flag, + int const* fn_to_rep_idx, + int fn_to_rep_size, + int const* fn_to_nested_idx, + int fn_to_nested_size) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + cudf::detail::lists_column_device_view in{d_in}; + if (row >= in.size()) return; + + // Initialize repeated counts to 0 + for (int f = 0; f < num_repeated_fields; f++) { + repeated_info[flat_index( + static_cast(row), static_cast(num_repeated_fields), static_cast(f))] = + {0, 0}; + } + + // Initialize nested locations to not found + for (int f = 0; f < num_nested_fields; f++) { + nested_locations[flat_index( + static_cast(row), static_cast(num_nested_fields), static_cast(f))] = { + -1, 0}; + } + + if (in.nullable() && in.is_null(row)) return; + + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); + auto const* bytes = reinterpret_cast(child.data()); + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; + if (!check_message_bounds(start, end, child.size(), error_flag)) return; + + uint8_t const* cur = bytes + start; + uint8_t const* msg_end = bytes + end; + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) return; + int fn = tag.field_number; + int wt = tag.wire_type; + + if (fn_to_rep_idx != nullptr && fn > 0 && fn < fn_to_rep_size) { + int i = fn_to_rep_idx[fn]; + if (i >= 0) { + int schema_idx = repeated_field_indices[i]; + if (schema[schema_idx].depth == depth_level && + !count_repeated_element( + cur, + msg_end, + wt, + schema[schema_idx].wire_type, + repeated_info[flat_index(static_cast(row), + 
static_cast(num_repeated_fields), + static_cast(i))], + error_flag)) { + return; + } + } + } else { + for (int i = 0; i < num_repeated_fields; i++) { + int schema_idx = repeated_field_indices[i]; + if (schema[schema_idx].field_number == fn && schema[schema_idx].depth == depth_level) { + if (!count_repeated_element( + cur, + msg_end, + wt, + schema[schema_idx].wire_type, + repeated_info[flat_index(static_cast(row), + static_cast(num_repeated_fields), + static_cast(i))], + error_flag)) { + return; + } + } + } + } + + // Check nested message fields at this depth + auto handle_nested = [&](int i) { + if (wt != wire_type_value(proto_wire_type::LEN)) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return false; + } + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return false; + } + if (len > static_cast(msg_end - cur - len_bytes) || + len > static_cast(cuda::std::numeric_limits::max())) { + set_error_once(error_flag, ERR_OVERFLOW); + return false; + } + auto const rel_offset64 = static_cast(cur - bytes - start); + if (rel_offset64 < cuda::std::numeric_limits::min() || + rel_offset64 > cuda::std::numeric_limits::max()) { + set_error_once(error_flag, ERR_OVERFLOW); + return false; + } + int32_t msg_offset; + if (!checked_add_int32(static_cast(rel_offset64), len_bytes, msg_offset)) { + set_error_once(error_flag, ERR_OVERFLOW); + return false; + } + nested_locations[flat_index( + static_cast(row), static_cast(num_nested_fields), static_cast(i))] = + {msg_offset, static_cast(len)}; + return true; + }; + + if (fn_to_nested_idx != nullptr && fn > 0 && fn < fn_to_nested_size) { + int i = fn_to_nested_idx[fn]; + if (i >= 0) { + int schema_idx = nested_field_indices[i]; + if (schema[schema_idx].depth == depth_level) { + if (!handle_nested(i)) return; + } + } + } else { + for (int i = 0; i < num_nested_fields; i++) { + int schema_idx = nested_field_indices[i]; + if (schema[schema_idx].field_number == 
fn && schema[schema_idx].depth == depth_level) { + if (!handle_nested(i)) return; + } + } + } + + // Skip to next field + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } +} + +/** + * Combined occurrence scan: scans each message ONCE and writes occurrences for ALL + * repeated fields simultaneously, scanning each message only once. + */ +CUDF_KERNEL void scan_all_repeated_occurrences_kernel(cudf::column_device_view const d_in, + repeated_field_scan_desc const* scan_descs, + int num_scan_fields, + int* error_flag, + int const* fn_to_desc_idx, + int fn_to_desc_size) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + cudf::detail::lists_column_device_view in{d_in}; + if (row >= in.size()) return; + + if (in.nullable() && in.is_null(row)) return; + + auto const base = in.offset_at(0); + auto const child = in.get_sliced_child(); + auto const* bytes = reinterpret_cast(child.data()); + int32_t start = in.offset_at(row) - base; + int32_t end = in.offset_at(row + 1) - base; + if (!check_message_bounds(start, end, child.size(), error_flag)) return; + + uint8_t const* cur = bytes + start; + uint8_t const* msg_end = bytes + end; + + constexpr int MAX_STACK_FIELDS = 128; + if (num_scan_fields > MAX_STACK_FIELDS) { + set_error_once(error_flag, ERR_SCHEMA_TOO_LARGE); + return; + } + int write_idx[MAX_STACK_FIELDS]; + for (int f = 0; f < num_scan_fields; f++) { + write_idx[f] = scan_descs[f].row_offsets[row]; + } + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) return; + int fn = tag.field_number; + int wt = tag.wire_type; + + auto try_scan = [&](int f) -> bool { + int target_wt = scan_descs[f].wire_type; + bool is_packed = (wt == wire_type_value(proto_wire_type::LEN) && + target_wt != wire_type_value(proto_wire_type::LEN)); + if (is_packed || wt == target_wt) { + return scan_repeated_element(cur, + msg_end, + bytes + start, + 
wt, + target_wt, + static_cast(row), + scan_descs[f].occurrences, + write_idx[f], + scan_descs[f].row_offsets[row + 1], + error_flag); + } + set_error_once(error_flag, ERR_WIRE_TYPE); + return false; + }; + + if (fn_to_desc_idx != nullptr && fn > 0 && fn < fn_to_desc_size) { + int f = fn_to_desc_idx[fn]; + if (f >= 0 && f < num_scan_fields) { + if (!try_scan(f)) return; + } + } else { + for (int f = 0; f < num_scan_fields; f++) { + if (scan_descs[f].field_number == fn) { + if (!try_scan(f)) return; + } + } + } + + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } + + for (int f = 0; f < num_scan_fields; f++) { + if (write_idx[f] != scan_descs[f].row_offsets[row + 1]) { + set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); + return; + } + } +} + +// ============================================================================ +// Nested message scanning kernels +// ============================================================================ + +/** + * Scan nested message fields. + * Each row represents a nested message at a specific parent location. + * This kernel finds fields within the nested message bytes. + * + * Note: this path intentionally keeps a linear child-field scan. Unlike repeated-message child + * scanning, previous benchmarking did not show a stable win from adding a host-built lookup table + * here, so we keep the simpler implementation unless that trade-off changes. 
+ */ +CUDF_KERNEL void scan_nested_message_fields_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + int num_parent_rows, + field_descriptor const* field_descs, + int num_fields, + field_location* output_locations, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_parent_rows) return; + + for (int f = 0; f < num_fields; f++) { + output_locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(f))] = {-1, 0}; + } + + auto const& parent_loc = parent_locations[row]; + if (parent_loc.offset < 0) return; + + auto parent_row_start = parent_row_offsets[row] - parent_base_offset; + int64_t nested_start_off = static_cast(parent_row_start) + parent_loc.offset; + int64_t nested_end_off = nested_start_off + parent_loc.length; + if (nested_start_off < 0 || nested_end_off > message_data_size) { + set_error_once(error_flag, ERR_BOUNDS); + return; + } + uint8_t const* nested_start = message_data + nested_start_off; + uint8_t const* nested_end = nested_start + parent_loc.length; + + uint8_t const* cur = nested_start; + + while (cur < nested_end) { + proto_tag tag; + if (!decode_tag(cur, nested_end, tag, error_flag)) return; + int fn = tag.field_number; + int wt = tag.wire_type; + + for (int f = 0; f < num_fields; f++) { + if (field_descs[f].field_number == fn) { + if (field_descs[f].is_repeated) { + // Repeated children are handled by the dedicated count/scan path, not by + // the direct-child location scan used for scalar/nested singleton fields. 
+ break; + } + if (wt != field_descs[f].expected_wire_type) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return; + } + + int data_offset = static_cast(cur - nested_start); + + if (wt == wire_type_value(proto_wire_type::LEN)) { + uint64_t len; + int len_bytes; + if (!read_varint(cur, nested_end, len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + if (len > static_cast(nested_end - cur - len_bytes) || + len > static_cast(cuda::std::numeric_limits::max())) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + int32_t data_location; + if (!checked_add_int32(data_offset, len_bytes, data_location)) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + output_locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(f))] = { + data_location, static_cast(len)}; + } else { + int field_size = get_wire_type_size(wt, cur, nested_end); + if (field_size < 0) { + set_error_once(error_flag, ERR_FIELD_SIZE); + return; + } + output_locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(f))] = { + data_offset, field_size}; + } + break; + } + } + + uint8_t const* next; + if (!skip_field(cur, nested_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } +} + +/** + * Scan for child fields within repeated message occurrences. + * Each occurrence is a protobuf message, and we need to find child field locations within it. 
+ */ +CUDF_KERNEL void scan_repeated_message_children_kernel( + uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* msg_row_offsets, // Row offset for each occurrence + field_location const* + msg_locs, // Location of each message occurrence (offset within row, length) + int num_occurrences, + field_descriptor const* child_descs, + int num_child_fields, + field_location* child_locs, // Output: [num_occurrences * num_child_fields] + int* error_flag, + int const* child_lookup, + int child_lookup_size) { + auto occ_idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (occ_idx >= num_occurrences) return; + + // Initialize child locations to not found + for (int f = 0; f < num_child_fields; f++) { + child_locs[flat_index(static_cast(occ_idx), + static_cast(num_child_fields), + static_cast(f))] = {-1, 0}; + } + + auto const& msg_loc = msg_locs[occ_idx]; + if (msg_loc.offset < 0) return; + + cudf::size_type row_offset = msg_row_offsets[occ_idx]; + int64_t msg_start_off = static_cast(row_offset) + msg_loc.offset; + int64_t msg_end_off = msg_start_off + msg_loc.length; + if (msg_start_off < 0 || msg_end_off > message_data_size) { + set_error_once(error_flag, ERR_BOUNDS); + return; + } + uint8_t const* msg_start = message_data + msg_start_off; + uint8_t const* msg_end = msg_start + msg_loc.length; + + uint8_t const* cur = msg_start; + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) return; + int fn = tag.field_number; + int wt = tag.wire_type; + + int f = lookup_field(fn, child_lookup, child_lookup_size, child_descs, num_child_fields); + if (f >= 0) { + if (child_descs[f].is_repeated) { + // Repeated children are decoded by build_repeated_child_list_column via the + // nested repeated count/scan kernels, so do not record a singleton location here. 
+ } else if (wt != child_descs[f].expected_wire_type) { + set_error_once(error_flag, ERR_WIRE_TYPE); + return; + } else { + int data_offset = static_cast(cur - msg_start); + + if (wt == wire_type_value(proto_wire_type::LEN)) { + uint64_t len; + int len_bytes; + if (!read_varint(cur, msg_end, len, len_bytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + if (len > static_cast(msg_end - cur - len_bytes) || + len > static_cast(cuda::std::numeric_limits::max())) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + int32_t data_location; + if (!checked_add_int32(data_offset, len_bytes, data_location)) { + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + child_locs[flat_index(static_cast(occ_idx), + static_cast(num_child_fields), + static_cast(f))] = {data_location, + static_cast(len)}; + } else { + // For varint/fixed types, store offset and estimated length + int32_t data_length = 0; + if (wt == wire_type_value(proto_wire_type::VARINT)) { + uint64_t dummy; + int vbytes; + if (!read_varint(cur, msg_end, dummy, vbytes)) { + set_error_once(error_flag, ERR_VARINT); + return; + } + data_length = vbytes; + } else if (wt == wire_type_value(proto_wire_type::I32BIT)) { + if (msg_end - cur < 4) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } + data_length = 4; + } else if (wt == wire_type_value(proto_wire_type::I64BIT)) { + if (msg_end - cur < 8) { + set_error_once(error_flag, ERR_FIXED_LEN); + return; + } + data_length = 8; + } + child_locs[flat_index(static_cast(occ_idx), + static_cast(num_child_fields), + static_cast(f))] = {data_offset, data_length}; + } + } + } + + // Skip to next field + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } } -CUDF_KERNEL void validate_enum_values_kernel(int32_t const*, bool*, bool*, int32_t const*, int, int) +/** + * Count repeated field occurrences within nested messages. 
+ * Similar to count_repeated_fields_kernel but operates on nested message locations. + * + * Note: unlike the top-level count_repeated_fields_kernel, this kernel does not perform + * a depth-level check because it operates within a specific parent message context where + * the depth is implicitly fixed. Callers must pre-filter repeated_indices to include only + * fields at the expected child depth. + */ +CUDF_KERNEL void count_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + repeated_field_info* repeated_info, + int num_repeated, + int const* repeated_indices, + int* error_flag) { + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + // Initialize counts + for (int ri = 0; ri < num_repeated; ri++) { + repeated_info[flat_index( + static_cast(row), static_cast(num_repeated), static_cast(ri))] = {0, + 0}; + } + + auto const& parent_loc = parent_locs[row]; + if (parent_loc.offset < 0) return; + + cudf::size_type row_off; + row_off = row_offsets[row] - base_offset; + + int64_t msg_start_off = static_cast(row_off) + parent_loc.offset; + int64_t msg_end_off = msg_start_off + parent_loc.length; + if (msg_start_off < 0 || msg_end_off > message_data_size) { + set_error_once(error_flag, ERR_BOUNDS); + return; + } + + uint8_t const* msg_start = message_data + msg_start_off; + uint8_t const* msg_end = msg_start + parent_loc.length; + uint8_t const* cur = msg_start; + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) return; + int fn = tag.field_number; + int wt = tag.wire_type; + + // After decode_tag, `cur` points past the tag bytes (at the field data). + // Both count_repeated_element and skip_field expect this post-tag position. 
+ for (int ri = 0; ri < num_repeated; ri++) { + int schema_idx = repeated_indices[ri]; + if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { + if (!count_repeated_element(cur, + msg_end, + wt, + schema[schema_idx].wire_type, + repeated_info[flat_index(static_cast(row), + static_cast(num_repeated), + static_cast(ri))], + error_flag)) { + return; + } + } + } + + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } } -} // namespace +/** + * Scan for repeated field occurrences of a single repeated field within nested + * messages. Must be launched once per repeated field — the caller passes + * exactly one schema index via repeated_indices[0]. + * + * Note: no depth-level check is performed; see count_repeated_in_nested_kernel comment. + */ +CUDF_KERNEL void scan_repeated_in_nested_kernel(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int32_t const* occ_prefix_sums, + int const* repeated_indices, + repeated_occurrence* occurrences, + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + int write_idx = occ_prefix_sums[row]; + int write_end = occ_prefix_sums[row + 1]; + auto const& parent_loc = parent_locs[row]; + if (parent_loc.offset < 0) { + if (write_idx != write_end) set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); + return; + } + + cudf::size_type row_off = row_offsets[row] - base_offset; + + int64_t msg_start_off = static_cast(row_off) + parent_loc.offset; + int64_t msg_end_off = msg_start_off + parent_loc.length; + if (msg_start_off < 0 || msg_end_off > message_data_size) { + set_error_once(error_flag, ERR_BOUNDS); + return; + } + + uint8_t const* msg_start = message_data + msg_start_off; + 
uint8_t const* msg_end = msg_start + parent_loc.length; + uint8_t const* cur = msg_start; + + int schema_idx = repeated_indices[0]; + + while (cur < msg_end) { + proto_tag tag; + if (!decode_tag(cur, msg_end, tag, error_flag)) return; + int fn = tag.field_number; + int wt = tag.wire_type; + + if (schema[schema_idx].field_number == fn && schema[schema_idx].is_repeated) { + if (!scan_repeated_element(cur, + msg_end, + msg_start, + wt, + schema[schema_idx].wire_type, + static_cast(row), + occurrences, + write_idx, + write_end, + error_flag)) { + return; + } + } + + uint8_t const* next; + if (!skip_field(cur, msg_end, wt, next)) { + set_error_once(error_flag, ERR_SKIP); + return; + } + cur = next; + } + + if (write_idx != write_end) set_error_once(error_flag, ERR_REPEATED_COUNT_MISMATCH); +} + +/** + * Kernel to compute nested struct locations from child field locations. + * Replaces host-side loop that was copying data D->H, processing, then H->D. + * This is a critical performance optimization. 
+ */ +CUDF_KERNEL void compute_nested_struct_locations_kernel( + field_location const* child_locs, // Child field locations from parent scan + field_location const* msg_locs, // Parent message locations + cudf::size_type const* msg_row_offsets, // Parent message row offsets + int child_idx, // Which child field is the nested struct + int num_child_fields, // Total number of child fields per occurrence + field_location* nested_locs, // Output: nested struct locations + cudf::size_type* nested_row_offsets, // Output: nested struct row offsets + int total_count, + int* error_flag) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_count) return; + + nested_locs[idx] = child_locs[flat_index(static_cast(idx), + static_cast(num_child_fields), + static_cast(child_idx))]; + auto sum = static_cast(msg_row_offsets[idx]) + msg_locs[idx].offset; + if (sum < cuda::std::numeric_limits::min() || + sum > cuda::std::numeric_limits::max()) { + nested_locs[idx] = {-1, 0}; + nested_row_offsets[idx] = 0; + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + nested_row_offsets[idx] = static_cast(sum); +} + +/** + * Kernel to compute absolute grandchild parent locations from parent and child locations. + * Computes: gc_parent_abs[i] = parent[i].offset + child[i * ncf + ci].offset + * This replaces host-side loop with D->H->D copy pattern. 
+ */ +CUDF_KERNEL void compute_grandchild_parent_locations_kernel( + field_location const* parent_locs, // Parent locations (row count) + field_location const* child_locs, // Child locations (row * num_child_fields) + int child_idx, // Which child field + int num_child_fields, // Total child fields per row + field_location* gc_parent_abs, // Output: absolute grandchild parent locations + int num_rows, + int* error_flag) +{ + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= num_rows) return; + + auto const& parent_loc = parent_locs[row]; + auto const& child_loc = child_locs[flat_index(static_cast(row), + static_cast(num_child_fields), + static_cast(child_idx))]; + + if (parent_loc.offset >= 0 && child_loc.offset >= 0) { + // Absolute offset = parent offset + child's relative offset + auto sum = static_cast(parent_loc.offset) + child_loc.offset; + if (sum < cuda::std::numeric_limits::min() || + sum > cuda::std::numeric_limits::max()) { + gc_parent_abs[row] = {-1, 0}; + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + gc_parent_abs[row].offset = static_cast(sum); + gc_parent_abs[row].length = child_loc.length; + } else { + gc_parent_abs[row] = {-1, 0}; + } +} + +/** + * Compute virtual parent row offsets and locations for repeated message occurrences + * inside nested messages. Each occurrence becomes a virtual "row" so that + * build_nested_struct_column can recursively process the children. 
+ */ +CUDF_KERNEL void compute_virtual_parents_for_nested_repeated_kernel( + repeated_occurrence const* occurrences, + cudf::size_type const* row_list_offsets, // original binary input list offsets + field_location const* parent_locations, // parent nested message locations + cudf::size_type* virtual_row_offsets, // output: [total_count] + field_location* virtual_parent_locs, // output: [total_count] + int total_count, + int* error_flag) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_count) return; + + auto const& occ = occurrences[idx]; + auto const& ploc = parent_locations[occ.row_idx]; + + virtual_row_offsets[idx] = row_list_offsets[occ.row_idx]; + + // Keep zero-length embedded messages as "present but empty". + // Protobuf allows an embedded message with length=0, which maps to a non-null + // struct with all-null children (not a null struct). + if (ploc.offset >= 0) { + auto sum = static_cast(ploc.offset) + occ.offset; + if (sum < cuda::std::numeric_limits::min() || + sum > cuda::std::numeric_limits::max()) { + virtual_parent_locs[idx] = {-1, 0}; + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + virtual_parent_locs[idx] = {static_cast(sum), occ.length}; + } else { + virtual_parent_locs[idx] = {-1, 0}; + } +} + +/** + * Kernel to compute message locations and row offsets from repeated occurrences. + * Replaces host-side loop that processed occurrences. 
+ */ +CUDF_KERNEL void compute_msg_locations_from_occurrences_kernel( + repeated_occurrence const* occurrences, // Repeated field occurrences + cudf::size_type const* list_offsets, // List offsets for rows + cudf::size_type base_offset, // Base offset to subtract + field_location* msg_locs, // Output: message locations + cudf::size_type* msg_row_offsets, // Output: message row offsets + int total_count, + int* error_flag) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total_count) return; + + auto const& occ = occurrences[idx]; + auto row_offset = static_cast(list_offsets[occ.row_idx]) - base_offset; + if (row_offset < cuda::std::numeric_limits::min() || + row_offset > cuda::std::numeric_limits::max()) { + msg_row_offsets[idx] = 0; + msg_locs[idx] = {-1, 0}; + set_error_once(error_flag, ERR_OVERFLOW); + return; + } + msg_row_offsets[idx] = static_cast(row_offset); + msg_locs[idx] = {occ.offset, occ.length}; +} + +/** + * Extract a single field's locations from a 2D strided array on the GPU. + * Replaces a D2H + CPU loop + H2D pattern for nested message location extraction. + */ +CUDF_KERNEL void extract_strided_locations_kernel(field_location const* nested_locations, + int field_idx, + int num_fields, + field_location* parent_locs, + int num_rows) +{ + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= num_rows) return; + parent_locs[row] = nested_locations[flat_index( + static_cast(row), static_cast(num_fields), static_cast(field_idx))]; +} + +// ============================================================================ +// Kernel to check required fields after scan pass +// ============================================================================ + +/** + * Check if any required fields are missing (offset < 0) and set error flag. + * This is called after the scan pass to validate required field constraints. 
+ */ +CUDF_KERNEL void check_required_fields_kernel( + field_location const* locations, // [num_rows * num_fields] + uint8_t const* is_required, // [num_fields] (1 = required, 0 = optional) + int num_fields, + int num_rows, + cudf::bitmask_type const* input_null_mask, // optional top-level input null mask + cudf::size_type input_offset, // bit offset for sliced top-level input + field_location const* parent_locs, // [num_rows] optional parent presence for nested rows + bool* row_force_null, // [top_level_num_rows] optional permissive row nulling + int32_t const* top_row_indices, // [num_rows] optional nested-row -> top-row mapping + int* error_flag) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + if (input_null_mask != nullptr && !cudf::bit_is_set(input_null_mask, row + input_offset)) { + return; + } + if (parent_locs != nullptr && parent_locs[row].offset < 0) return; + + for (int f = 0; f < num_fields; f++) { + if (is_required[f] != 0 && locations[flat_index(static_cast(row), + static_cast(num_fields), + static_cast(f))] + .offset < 0) { + if (row_force_null != nullptr) { + auto const top_row = + top_row_indices != nullptr ? top_row_indices[row] : static_cast(row); + row_force_null[top_row] = true; + } + // Required field is missing - set error flag + set_error_once(error_flag, ERR_REQUIRED); + return; // No need to check other fields for this row + } + } +} + +/** + * Validate enum values against a set of valid values. + * If a value is not in the valid set: + * 1. Mark the field as invalid (valid[row] = false) + * 2. Mark the row as having an invalid enum (row_has_invalid_enum[row] = true) + * + * This matches Spark CPU PERMISSIVE mode behavior: when an unknown enum value is + * encountered, the entire struct row is set to null (not just the enum field). + * + * The valid_values array must be sorted for binary search. + * + * @note Time complexity: O(log(num_valid_values)) per row. 
+ + */ +CUDF_KERNEL void validate_enum_values_kernel( + int32_t const* values, // [num_rows] extracted enum values + bool* valid, // [num_rows] field validity flags (will be modified) + bool* row_has_invalid_enum, // [num_rows] row-level invalid enum flag (will be set to true) + int32_t const* valid_enum_values, // sorted array of valid enum values + int num_valid_values, // size of valid_enum_values + int num_rows) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + // Skip if already invalid (field was missing) - missing field is not an enum error + if (!valid[row]) return; + + int32_t val = values[row]; + + // Binary search for the value in valid_enum_values + int left = 0; + int right = num_valid_values - 1; + bool found = false; + + while (left <= right) { + int mid = left + (right - left) / 2; + if (valid_enum_values[mid] == val) { + found = true; + break; + } else if (valid_enum_values[mid] < val) { + left = mid + 1; + } else { + right = mid - 1; + } + } + + // If not found, mark as invalid + if (!found) { + valid[row] = false; + // Also mark the row as having an invalid enum - this will null the entire struct row + row_has_invalid_enum[row] = true; + } +} + +/** + * Compute output UTF-8 length for enum-as-string rows. + * Invalid/missing values produce length 0 (null row/field semantics handled by valid[] and + * row_has_invalid_enum). 
+ + */ +CUDF_KERNEL void compute_enum_string_lengths_kernel( + int32_t const* values, // [num_rows] enum numeric values + bool const* valid, // [num_rows] field validity + int32_t const* valid_enum_values, // sorted enum numeric values + int32_t const* enum_name_offsets, // [num_valid_values + 1] + int num_valid_values, + int32_t* lengths, // [num_rows] + int num_rows) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + + if (!valid[row]) { + lengths[row] = 0; + return; + } + + int32_t val = values[row]; + int left = 0; + int right = num_valid_values - 1; + while (left <= right) { + int mid = left + (right - left) / 2; + int32_t mid_val = valid_enum_values[mid]; + if (mid_val == val) { + lengths[row] = enum_name_offsets[mid + 1] - enum_name_offsets[mid]; + return; + } else if (mid_val < val) { + left = mid + 1; + } else { + right = mid - 1; + } + } + + // Should not happen when validate_enum_values_kernel has already run, but keep safe. + lengths[row] = 0; +} + +/** + * Copy enum-as-string UTF-8 bytes into output chars buffer using precomputed row offsets. 
+ */ +CUDF_KERNEL void copy_enum_string_chars_kernel( + int32_t const* values, // [num_rows] enum numeric values + bool const* valid, // [num_rows] field validity + int32_t const* valid_enum_values, // sorted enum numeric values + int32_t const* enum_name_offsets, // [num_valid_values + 1] + uint8_t const* enum_name_chars, // concatenated enum UTF-8 names + int num_valid_values, + int32_t const* output_offsets, // [num_rows + 1] + char* out_chars, // [total_chars] + int num_rows) +{ + auto row = static_cast(blockIdx.x * blockDim.x + threadIdx.x); + if (row >= num_rows) return; + if (!valid[row]) return; + + int32_t val = values[row]; + int left = 0; + int right = num_valid_values - 1; + while (left <= right) { + int mid = left + (right - left) / 2; + int32_t mid_val = valid_enum_values[mid]; + if (mid_val == val) { + int32_t src_begin = enum_name_offsets[mid]; + int32_t src_end = enum_name_offsets[mid + 1]; + int32_t dst_begin = output_offsets[row]; + memcpy(out_chars + dst_begin, + enum_name_chars + src_begin, + static_cast(src_end - src_begin)); + return; + } else if (mid_val < val) { + left = mid + 1; + } else { + right = mid - 1; + } + } +} + +} // anonymous namespace + +// ============================================================================ +// Host wrapper functions — callable from other translation units +// ============================================================================ void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stream_view stream) { @@ -62,6 +1335,344 @@ void set_error_once_async(int* error_flag, int error_code, rmm::cuda_stream_view CUDF_CUDA_TRY(cudaPeekAtLastError()); } +void launch_scan_all_fields(cudf::column_device_view const& d_in, + field_descriptor const* field_descs, + int num_fields, + int const* field_lookup, + int field_lookup_size, + field_location* locations, + int* error_flag, + bool* row_has_invalid_data, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto 
const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + scan_all_fields_kernel<<>>(d_in, + field_descs, + num_fields, + field_lookup, + field_lookup_size, + locations, + error_flag, + row_has_invalid_data); +} + +void launch_count_repeated_fields(cudf::column_device_view const& d_in, + device_nested_field_descriptor const* schema, + int num_fields, + int depth_level, + repeated_field_info* repeated_info, + int num_repeated_fields, + int const* repeated_field_indices, + field_location* nested_locations, + int num_nested_fields, + int const* nested_field_indices, + int* error_flag, + int const* fn_to_rep_idx, + int fn_to_rep_size, + int const* fn_to_nested_idx, + int fn_to_nested_size, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + count_repeated_fields_kernel<<>>( + d_in, + schema, + num_fields, + depth_level, + repeated_info, + num_repeated_fields, + repeated_field_indices, + nested_locations, + num_nested_fields, + nested_field_indices, + error_flag, + fn_to_rep_idx, + fn_to_rep_size, + fn_to_nested_idx, + fn_to_nested_size); +} + +void launch_scan_all_repeated_occurrences(cudf::column_device_view const& d_in, + repeated_field_scan_desc const* scan_descs, + int num_scan_fields, + int* error_flag, + int const* fn_to_desc_idx, + int fn_to_desc_size, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + scan_all_repeated_occurrences_kernel<<>>( + d_in, scan_descs, num_scan_fields, error_flag, fn_to_desc_idx, fn_to_desc_size); +} + +void launch_extract_strided_locations(field_location const* nested_locations, + int field_idx, + int num_fields, + field_location* parent_locs, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows 
+ THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + extract_strided_locations_kernel<<>>( + nested_locations, field_idx, num_fields, parent_locs, num_rows); +} + +void launch_scan_nested_message_fields(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + int num_parent_rows, + field_descriptor const* field_descs, + int num_fields, + field_location* output_locations, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (num_parent_rows == 0) return; + auto const blocks = + static_cast((num_parent_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + scan_nested_message_fields_kernel<<>>( + message_data, + message_data_size, + parent_row_offsets, + parent_base_offset, + parent_locations, + num_parent_rows, + field_descs, + num_fields, + output_locations, + error_flag); +} + +void launch_scan_repeated_message_children(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* msg_row_offsets, + field_location const* msg_locs, + int num_occurrences, + field_descriptor const* child_descs, + int num_child_fields, + field_location* child_locs, + int* error_flag, + int const* child_lookup, + int child_lookup_size, + rmm::cuda_stream_view stream) +{ + if (num_occurrences == 0) return; + auto const blocks = + static_cast((num_occurrences + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + scan_repeated_message_children_kernel<<>>( + message_data, + message_data_size, + msg_row_offsets, + msg_locs, + num_occurrences, + child_descs, + num_child_fields, + child_locs, + error_flag, + child_lookup, + child_lookup_size); +} + +void launch_count_repeated_in_nested(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int 
num_fields, + repeated_field_info* repeated_info, + int num_repeated, + int const* repeated_indices, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + count_repeated_in_nested_kernel<<>>( + message_data, + message_data_size, + row_offsets, + base_offset, + parent_locs, + num_rows, + schema, + num_fields, + repeated_info, + num_repeated, + repeated_indices, + error_flag); +} + +void launch_scan_repeated_in_nested(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int32_t const* occ_prefix_sums, + int const* repeated_indices, + repeated_occurrence* occurrences, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + scan_repeated_in_nested_kernel<<>>( + message_data, + message_data_size, + row_offsets, + base_offset, + parent_locs, + num_rows, + schema, + occ_prefix_sums, + repeated_indices, + occurrences, + error_flag); +} + +void launch_compute_nested_struct_locations(field_location const* child_locs, + field_location const* msg_locs, + cudf::size_type const* msg_row_offsets, + int child_idx, + int num_child_fields, + field_location* nested_locs, + cudf::size_type* nested_row_offsets, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (total_count == 0) return; + auto const blocks = static_cast((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + compute_nested_struct_locations_kernel<<>>( + child_locs, + msg_locs, + msg_row_offsets, + child_idx, + num_child_fields, + nested_locs, + nested_row_offsets, + total_count, + error_flag); +} + +void launch_compute_grandchild_parent_locations(field_location 
const* parent_locs, + field_location const* child_locs, + int child_idx, + int num_child_fields, + field_location* gc_parent_abs, + int num_rows, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + compute_grandchild_parent_locations_kernel<<>>( + parent_locs, child_locs, child_idx, num_child_fields, gc_parent_abs, num_rows, error_flag); +} + +void launch_compute_virtual_parents_for_nested_repeated(repeated_occurrence const* occurrences, + cudf::size_type const* row_list_offsets, + field_location const* parent_locations, + cudf::size_type* virtual_row_offsets, + field_location* virtual_parent_locs, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (total_count == 0) return; + auto const blocks = static_cast((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + compute_virtual_parents_for_nested_repeated_kernel<<>>(occurrences, + row_list_offsets, + parent_locations, + virtual_row_offsets, + virtual_parent_locs, + total_count, + error_flag); +} + +void launch_compute_msg_locations_from_occurrences(repeated_occurrence const* occurrences, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + field_location* msg_locs, + cudf::size_type* msg_row_offsets, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream) +{ + if (total_count == 0) return; + auto const blocks = static_cast((total_count + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + compute_msg_locations_from_occurrences_kernel<<>>( + occurrences, list_offsets, base_offset, msg_locs, msg_row_offsets, total_count, error_flag); +} + +void launch_validate_enum_values(int32_t const* values, + bool* valid, + bool* row_has_invalid_enum, + int32_t const* valid_enum_values, + int num_valid_values, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + 
THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + validate_enum_values_kernel<<>>( + values, valid, row_has_invalid_enum, valid_enum_values, num_valid_values, num_rows); +} + +void launch_compute_enum_string_lengths(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + int num_valid_values, + int32_t* lengths, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + compute_enum_string_lengths_kernel<<>>( + values, valid, valid_enum_values, enum_name_offsets, num_valid_values, lengths, num_rows); +} + +void launch_copy_enum_string_chars(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + uint8_t const* enum_name_chars, + int num_valid_values, + int32_t const* output_offsets, + char* out_chars, + int num_rows, + rmm::cuda_stream_view stream) +{ + if (num_rows == 0) return; + auto const blocks = static_cast((num_rows + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); + copy_enum_string_chars_kernel<<>>(values, + valid, + valid_enum_values, + enum_name_offsets, + enum_name_chars, + num_valid_values, + output_offsets, + out_chars, + num_rows); +} + void maybe_check_required_fields(field_location const* locations, std::vector const& field_indices, std::vector const& schema, @@ -109,7 +1720,7 @@ void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_ bool propagate_to_rows, rmm::cuda_stream_view stream) { - if (num_items == 0 || row_invalid.size() == 0 || !propagate_to_rows) { return; } + if (num_items == 0 || row_invalid.size() == 0 || !propagate_to_rows) return; if (top_row_indices == nullptr) { CUDF_EXPECTS(static_cast(num_items) <= row_invalid.size(), @@ -131,7 +1742,7 @@ void propagate_invalid_enum_flags_to_rows(rmm::device_uvector const& item_ [item_invalid = item_invalid.data(), top_row_indices, row_invalid 
= row_invalid.data()] __device__(int idx) { - if (item_invalid[idx]) { row_invalid[top_row_indices[idx]] = true; } + if (item_invalid[idx]) row_invalid[top_row_indices[idx]] = true; }); } @@ -144,7 +1755,7 @@ void validate_enum_and_propagate_rows(rmm::device_uvector const& values bool propagate_to_rows, rmm::cuda_stream_view stream) { - if (num_items == 0 || valid_enums.empty()) { return; } + if (num_items == 0 || valid_enums.empty()) return; auto const blocks = static_cast((num_items + THREADS_PER_BLOCK - 1u) / THREADS_PER_BLOCK); auto d_valid_enums = cudf::detail::make_device_uvector_async( diff --git a/src/main/cpp/src/protobuf/protobuf_kernels.cuh b/src/main/cpp/src/protobuf/protobuf_kernels.cuh index 38917b6a83..2ba87361d9 100644 --- a/src/main/cpp/src/protobuf/protobuf_kernels.cuh +++ b/src/main/cpp/src/protobuf/protobuf_kernels.cuh @@ -403,6 +403,173 @@ CUDF_KERNEL void extract_lengths_kernel(LocationProvider loc_provider, } } +// ============================================================================ +// Host wrapper declarations for kernel launches +// ============================================================================ + +void launch_scan_all_fields(cudf::column_device_view const& d_in, + field_descriptor const* field_descs, + int num_fields, + int const* field_lookup, + int field_lookup_size, + field_location* locations, + int* error_flag, + bool* row_has_invalid_data, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_count_repeated_fields(cudf::column_device_view const& d_in, + device_nested_field_descriptor const* schema, + int num_fields, + int depth_level, + repeated_field_info* repeated_info, + int num_repeated_fields, + int const* repeated_field_indices, + field_location* nested_locations, + int num_nested_fields, + int const* nested_field_indices, + int* error_flag, + int const* fn_to_rep_idx, + int fn_to_rep_size, + int const* fn_to_nested_idx, + int fn_to_nested_size, + int num_rows, + rmm::cuda_stream_view stream); + 
+void launch_scan_all_repeated_occurrences(cudf::column_device_view const& d_in, + repeated_field_scan_desc const* scan_descs, + int num_scan_fields, + int* error_flag, + int const* fn_to_desc_idx, + int fn_to_desc_size, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_extract_strided_locations(field_location const* nested_locations, + int field_idx, + int num_fields, + field_location* parent_locs, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_scan_nested_message_fields(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* parent_row_offsets, + cudf::size_type parent_base_offset, + field_location const* parent_locations, + int num_parent_rows, + field_descriptor const* field_descs, + int num_fields, + field_location* output_locations, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_scan_repeated_message_children(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* msg_row_offsets, + field_location const* msg_locs, + int num_occurrences, + field_descriptor const* child_descs, + int num_child_fields, + field_location* child_locs, + int* error_flag, + int const* child_lookup, + int child_lookup_size, + rmm::cuda_stream_view stream); + +void launch_count_repeated_in_nested(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int num_fields, + repeated_field_info* repeated_info, + int num_repeated, + int const* repeated_indices, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_scan_repeated_in_nested(uint8_t const* message_data, + cudf::size_type message_data_size, + cudf::size_type const* row_offsets, + cudf::size_type base_offset, + field_location const* parent_locs, + int num_rows, + device_nested_field_descriptor const* schema, + int32_t 
const* occ_prefix_sums, + int const* repeated_indices, + repeated_occurrence* occurrences, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_compute_nested_struct_locations(field_location const* child_locs, + field_location const* msg_locs, + cudf::size_type const* msg_row_offsets, + int child_idx, + int num_child_fields, + field_location* nested_locs, + cudf::size_type* nested_row_offsets, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_compute_grandchild_parent_locations(field_location const* parent_locs, + field_location const* child_locs, + int child_idx, + int num_child_fields, + field_location* gc_parent_abs, + int num_rows, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_compute_virtual_parents_for_nested_repeated(repeated_occurrence const* occurrences, + cudf::size_type const* row_list_offsets, + field_location const* parent_locations, + cudf::size_type* virtual_row_offsets, + field_location* virtual_parent_locs, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_compute_msg_locations_from_occurrences(repeated_occurrence const* occurrences, + cudf::size_type const* list_offsets, + cudf::size_type base_offset, + field_location* msg_locs, + cudf::size_type* msg_row_offsets, + int total_count, + int* error_flag, + rmm::cuda_stream_view stream); + +void launch_validate_enum_values(int32_t const* values, + bool* valid, + bool* row_has_invalid_enum, + int32_t const* valid_enum_values, + int num_valid_values, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_compute_enum_string_lengths(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + int num_valid_values, + int32_t* lengths, + int num_rows, + rmm::cuda_stream_view stream); + +void launch_copy_enum_string_chars(int32_t const* values, + bool const* valid, + int32_t const* valid_enum_values, + int32_t const* enum_name_offsets, + 
uint8_t const* enum_name_chars, + int num_valid_values, + int32_t const* output_offsets, + char* out_chars, + int num_rows, + rmm::cuda_stream_view stream); + // ============================================================================ // Host-side template helpers that launch CUDA kernels // ============================================================================ @@ -840,6 +1007,7 @@ inline std::unique_ptr build_repeated_scalar_column( rmm::device_async_resource_ref mr) { auto const input_null_count = binary_input.null_count(); + auto const field_type_id = static_cast(field_desc.output_type_id); if (total_count == 0) { rmm::device_uvector offsets(num_rows + 1, stream, mr); @@ -849,9 +1017,7 @@ inline std::unique_ptr build_repeated_scalar_column( offsets.release(), rmm::device_buffer{}, 0); - auto elem_type = field_desc.output_type_id == static_cast(cudf::type_id::LIST) - ? cudf::type_id::UINT8 - : static_cast(field_desc.output_type_id); + auto elem_type = field_type_id == cudf::type_id::LIST ? 
cudf::type_id::UINT8 : field_type_id; auto child_col = make_empty_column_safe(cudf::data_type{elem_type}, stream, mr); if (input_null_count > 0) { @@ -915,11 +1081,7 @@ inline std::unique_ptr build_repeated_scalar_column( rmm::device_buffer{}, 0); auto child_col = std::make_unique( - cudf::data_type{static_cast(field_desc.output_type_id)}, - total_count, - values.release(), - rmm::device_buffer{}, - 0); + cudf::data_type{field_type_id}, total_count, values.release(), rmm::device_buffer{}, 0); if (input_null_count > 0) { auto null_mask = cudf::copy_bitmask(binary_input, stream, mr); diff --git a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java index 16e836bb02..2f2b3b4cf1 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java @@ -16,97 +16,3401 @@ package com.nvidia.spark.rapids.jni; +import ai.rapids.cudf.AssertUtils; import ai.rapids.cudf.ColumnVector; import ai.rapids.cudf.ColumnView; import ai.rapids.cudf.DType; import ai.rapids.cudf.HostColumnVector; +import ai.rapids.cudf.HostColumnVectorCore; +import ai.rapids.cudf.HostColumnVector.*; import ai.rapids.cudf.Table; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; -import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.ByteArrayOutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; + +/** + * Tests for the Protobuf GPU decoder. 
+ * + * Test cases are inspired by Google's protobuf conformance test suite: + * https://github.com/protocolbuffers/protobuf/tree/main/conformance + */ +public class ProtobufTest { + + // ============================================================================ + // Helper methods for encoding protobuf wire format + // ============================================================================ + + /** Encode a value as a varint (variable-length integer). */ + private static byte[] encodeVarint(long value) { + long v = value; + byte[] tmp = new byte[10]; + int idx = 0; + while ((v & ~0x7FL) != 0) { + tmp[idx++] = (byte) ((v & 0x7F) | 0x80); + v >>>= 7; + } + tmp[idx++] = (byte) (v & 0x7F); + byte[] out = new byte[idx]; + System.arraycopy(tmp, 0, out, 0, idx); + return out; + } + + /** ZigZag encode a signed 32-bit integer, returning as unsigned long for varint encoding. */ + private static long zigzagEncode32(int n) { + return Integer.toUnsignedLong((n << 1) ^ (n >> 31)); + } + + /** ZigZag encode a signed 64-bit integer. */ + private static long zigzagEncode64(long n) { + return (n << 1) ^ (n >> 63); + } + + /** Encode a 32-bit value in little-endian (fixed32). */ + private static byte[] encodeFixed32(int v) { + return ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(v).array(); + } + + /** Encode a 64-bit value in little-endian (fixed64). */ + private static byte[] encodeFixed64(long v) { + return ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putLong(v).array(); + } + + /** Encode a float in little-endian (fixed32 wire type). */ + private static byte[] encodeFloat(float f) { + return ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putFloat(f).array(); + } + + /** Encode a double in little-endian (fixed64 wire type). */ + private static byte[] encodeDouble(double d) { + return ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN).putDouble(d).array(); + } + + /** Create a protobuf tag (field number + wire type). 
*/ + private static byte[] tag(int fieldNumber, int wireType) { + return encodeVarint(((long) fieldNumber << 3) | wireType); + } + + // Wire type constants + private static final int WT_VARINT = 0; + private static final int WT_64BIT = 1; + private static final int WT_LEN = 2; + private static final int WT_32BIT = 5; + + private static Byte[] box(byte[] bytes) { + if (bytes == null) return null; + Byte[] out = new Byte[bytes.length]; + for (int i = 0; i < bytes.length; i++) { + out[i] = bytes[i]; + } + return out; + } + + private static Byte[] concat(Byte[]... parts) { + int len = 0; + for (Byte[] p : parts) if (p != null) len += p.length; + Byte[] out = new Byte[len]; + int pos = 0; + for (Byte[] p : parts) { + if (p != null) { + System.arraycopy(p, 0, out, pos, p.length); + pos += p.length; + } + } + return out; + } + + // ============================================================================ + // Helper methods for calling the unified API + // ============================================================================ + + private static ProtobufSchemaDescriptor makeScalarSchema(int[] fieldNumbers, int[] typeIds, + int[] encodings) { + int n = fieldNumbers.length; + int[] parentIndices = new int[n]; + int[] depthLevels = new int[n]; + int[] wireTypes = new int[n]; + boolean[] isRepeated = new boolean[n]; + boolean[] isRequired = new boolean[n]; + boolean[] hasDefault = new boolean[n]; + long[] defaultInts = new long[n]; + double[] defaultFloats = new double[n]; + boolean[] defaultBools = new boolean[n]; + byte[][] defaultStrings = new byte[n][]; + int[][] enumValid = new int[n][]; + byte[][][] enumNames = new byte[n][][]; + + java.util.Arrays.fill(parentIndices, -1); + for (int i = 0; i < n; i++) { + wireTypes[i] = deriveWireType(typeIds[i], encodings[i]); + } + return new ProtobufSchemaDescriptor(fieldNumbers, parentIndices, depthLevels, + wireTypes, typeIds, encodings, isRepeated, isRequired, hasDefault, + defaultInts, defaultFloats, defaultBools, 
defaultStrings, enumValid, enumNames); + } + + private static int deriveWireType(int typeId, int encoding) { + if (encoding == Protobuf.ENC_ENUM_STRING) return Protobuf.WT_VARINT; + if (typeId == DType.FLOAT32.getTypeId().getNativeId()) return Protobuf.WT_32BIT; + if (typeId == DType.FLOAT64.getTypeId().getNativeId()) return Protobuf.WT_64BIT; + if (typeId == DType.STRING.getTypeId().getNativeId()) return Protobuf.WT_LEN; + if (typeId == DType.LIST.getTypeId().getNativeId()) return Protobuf.WT_LEN; + if (typeId == DType.STRUCT.getTypeId().getNativeId()) return Protobuf.WT_LEN; + if (encoding == Protobuf.ENC_FIXED) { + if (typeId == DType.INT64.getTypeId().getNativeId()) return Protobuf.WT_64BIT; + return Protobuf.WT_32BIT; + } + return Protobuf.WT_VARINT; + } + + /** + * Test-only convenience: wrap raw parallel arrays into a ProtobufSchemaDescriptor + * and decode. Avoids verbose ProtobufSchemaDescriptor construction at every call site. + */ + private static ColumnVector decodeRaw(ColumnView binaryInput, + int[] fieldNumbers, int[] parentIndices, int[] depthLevels, + int[] wireTypes, int[] outputTypeIds, int[] encodings, + boolean[] isRepeated, boolean[] isRequired, + boolean[] hasDefaultValue, long[] defaultInts, + double[] defaultFloats, boolean[] defaultBools, + byte[][] defaultStrings, int[][] enumValidValues, + boolean failOnErrors) { + return decodeRaw(binaryInput, fieldNumbers, parentIndices, depthLevels, + wireTypes, outputTypeIds, encodings, isRepeated, isRequired, + hasDefaultValue, defaultInts, defaultFloats, defaultBools, + defaultStrings, enumValidValues, new byte[fieldNumbers.length][][], failOnErrors); + } + + private static ColumnVector decodeRaw(ColumnView binaryInput, + int[] fieldNumbers, int[] parentIndices, int[] depthLevels, + int[] wireTypes, int[] outputTypeIds, int[] encodings, + boolean[] isRepeated, boolean[] isRequired, + boolean[] hasDefaultValue, long[] defaultInts, + double[] defaultFloats, boolean[] defaultBools, + byte[][] 
defaultStrings, int[][] enumValidValues, + byte[][][] enumNames, + boolean failOnErrors) { + return Protobuf.decodeToStruct(binaryInput, + new ProtobufSchemaDescriptor(fieldNumbers, parentIndices, depthLevels, + wireTypes, outputTypeIds, encodings, isRepeated, isRequired, + hasDefaultValue, defaultInts, defaultFloats, defaultBools, + defaultStrings, enumValidValues, enumNames), + failOnErrors); + } + + /** + * Helper to decode all scalar fields using the unified API. + * Builds a flat schema (parentIndices=-1, depth=0, isRepeated=false for all fields). + */ + private static ColumnVector decodeScalarFields(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] encodings, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + int[][] enumValidValues, + boolean failOnErrors) { + int numFields = fieldNumbers.length; + int[] parentIndices = new int[numFields]; + int[] depthLevels = new int[numFields]; + int[] wireTypes = new int[numFields]; + boolean[] isRepeated = new boolean[numFields]; + + java.util.Arrays.fill(parentIndices, -1); + // depthLevels already initialized to 0 + // isRepeated already initialized to false + for (int i = 0; i < numFields; i++) { + wireTypes[i] = deriveWireType(typeIds[i], encodings[i]); + } + + return Protobuf.decodeToStruct(binaryInput, + new ProtobufSchemaDescriptor(fieldNumbers, parentIndices, depthLevels, + wireTypes, typeIds, encodings, isRepeated, isRequired, hasDefaultValue, + defaultInts, defaultFloats, defaultBools, defaultStrings, enumValidValues, + new byte[fieldNumbers.length][][]), + failOnErrors); + } + + /** + * Helper method that wraps the unified API for tests that decode all scalar fields. 
+ */ + private static ColumnVector decodeAllFields(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] encodings) { + return decodeAllFields(binaryInput, fieldNumbers, typeIds, encodings, true); + } + + /** + * Helper method that wraps the unified API for tests that decode all scalar fields. + */ + private static ColumnVector decodeAllFields(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] encodings, + boolean failOnErrors) { + int numFields = fieldNumbers.length; + return decodeScalarFields(binaryInput, fieldNumbers, typeIds, encodings, + new boolean[numFields], new boolean[numFields], new long[numFields], + new double[numFields], new boolean[numFields], new byte[numFields][], + new int[numFields][], failOnErrors); + } + + private static void assertSingleNullStructRow(ColumnVector actual, String message) { + try (HostColumnVector hostStruct = actual.copyToHost()) { + assertEquals(1, actual.getNullCount(), message); + assertTrue(hostStruct.isNull(0), "Row 0 should be null"); + } + } + + /** + * Helper method for tests with required field support. 
+ */ + private static ColumnVector decodeAllFieldsWithRequired(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] encodings, + boolean[] isRequired, + boolean failOnErrors) { + int numFields = fieldNumbers.length; + return decodeScalarFields(binaryInput, fieldNumbers, typeIds, encodings, + isRequired, new boolean[numFields], new long[numFields], + new double[numFields], new boolean[numFields], new byte[numFields][], + new int[numFields][], failOnErrors); + } + + // ============================================================================ + // Basic Type Tests + // ============================================================================ + + @Test + void decodeVarintAndStringToStruct() { + // message Msg { int64 id = 1; string name = 2; } + // Row0: id=100, name="alice" + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(100)), + box(tag(2, WT_LEN)), + box(encodeVarint(5)), + box("alice".getBytes())); + + // Row1: id=200, name missing + Byte[] row1 = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(200))); + + // Row2: null input message + Byte[] row2 = null; + + try (Table input = new Table.TestBuilder().column(row0, row1, row2).build(); + ColumnVector expectedId = ColumnVector.fromBoxedLongs(100L, 200L, null); + ColumnVector expectedName = ColumnVector.fromStrings("alice", null, null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedId, expectedName); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{0, 0})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void decodeMoreTypes() { + // message Msg { uint32 u32 = 1; sint64 s64 = 2; fixed32 f32 = 3; bytes b = 4; } + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(4000000000L)), + box(tag(2, WT_VARINT)), + 
box(encodeVarint(zigzagEncode64(-1234567890123L))), + box(tag(3, WT_32BIT)), + box(encodeFixed32(12345)), + box(tag(4, WT_LEN)), + box(encodeVarint(3)), + box(new byte[]{1, 2, 3})); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row0}).build(); + ColumnVector expectedU32 = ColumnVector.fromBoxedLongs(4000000000L); + ColumnVector expectedS64 = ColumnVector.fromBoxedLongs(-1234567890123L); + ColumnVector expectedF32 = ColumnVector.fromBoxedInts(12345); + ColumnVector expectedB = ColumnVector.fromLists( + new ListType(true, new BasicType(true, DType.UINT8)), + Arrays.asList((byte) 1, (byte) 2, (byte) 3)); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1, 2, 3, 4}, + new int[]{ + DType.UINT32.getTypeId().getNativeId(), + DType.INT64.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.LIST.getTypeId().getNativeId()}, + new int[]{ + Protobuf.ENC_DEFAULT, + Protobuf.ENC_ZIGZAG, + Protobuf.ENC_FIXED, + Protobuf.ENC_DEFAULT})) { + try (ColumnVector expectedU32Correct = expectedU32.castTo(DType.UINT32); + ColumnVector expectedStructCorrect = ColumnVector.makeStruct( + expectedU32Correct, expectedS64, expectedF32, expectedB)) { + AssertUtils.assertStructColumnsAreEqual(expectedStructCorrect, actualStruct); + } + } + } + + @Test + void decodeFloatDoubleAndBool() { + // message Msg { bool flag = 1; float f32 = 2; double f64 = 3; } + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), new Byte[]{(byte)0x01}, // bool=true + box(tag(2, WT_32BIT)), box(encodeFloat(3.14f)), + box(tag(3, WT_64BIT)), box(encodeDouble(2.71828))); + + Byte[] row1 = concat( + box(tag(1, WT_VARINT)), new Byte[]{(byte)0x00}, // bool=false + box(tag(2, WT_32BIT)), box(encodeFloat(-1.5f)), + box(tag(3, WT_64BIT)), box(encodeDouble(0.0))); + + try (Table input = new Table.TestBuilder().column(row0, row1).build(); + ColumnVector expectedBool = ColumnVector.fromBoxedBooleans(true, false); + ColumnVector expectedFloat = 
ColumnVector.fromBoxedFloats(3.14f, -1.5f); + ColumnVector expectedDouble = ColumnVector.fromBoxedDoubles(2.71828, 0.0); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedBool, expectedFloat, expectedDouble); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1, 2, 3}, + new int[]{ + DType.BOOL8.getTypeId().getNativeId(), + DType.FLOAT32.getTypeId().getNativeId(), + DType.FLOAT64.getTypeId().getNativeId()}, + new int[]{0, 0, 0})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Schema Projection Tests (new API feature) + // ============================================================================ + + @Test + void testSchemaProjection() { + // message Msg { int64 f1 = 1; string f2 = 2; int32 f3 = 3; } + // Only decode f1 and f3, f2 should be null + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(100)), + box(tag(2, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes()), + box(tag(3, WT_VARINT)), box(encodeVarint(42))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row0}).build(); + // Expected: f1=100, f3=42 (schema projection: only decode these two) + ColumnVector expectedF1 = ColumnVector.fromBoxedLongs(100L); + ColumnVector expectedF3 = ColumnVector.fromBoxedInts(42); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedF1, expectedF3); + // Decode only f1 (field_number=1) and f3 (field_number=3), skip f2 + // With the unified API, we only include the fields we want in the schema + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1, 3}, // field numbers for f1 and f3 + new int[]{DType.INT64.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void 
testSchemaProjectionDecodeNone() { + // Decode no fields - all should be null + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(100)), + box(tag(2, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row0}).build(); + // With no fields in the schema, the GPU returns an empty struct + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{}, // no field numbers + new int[]{}, // no types + new int[]{})) { // no encodings + assertNotNull(actualStruct); + assertEquals(DType.STRUCT, actualStruct.getType()); + } + } + + // ============================================================================ + // Varint Boundary Tests + // ============================================================================ + + @Test + void testVarintMaxUint64() { + // Max uint64 = 0xFFFFFFFFFFFFFFFF = 18446744073709551615 + // Encoded as 10 bytes: FF FF FF FF FF FF FF FF FF 01 + Byte[] row = concat( + box(tag(1, WT_VARINT)), + new Byte[]{(byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0x01}); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.UINT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + try (ColumnVector expectedU64 = ColumnVector.fromBoxedLongs(-1L); // -1 as unsigned = max + ColumnVector expectedU64Correct = expectedU64.castTo(DType.UINT64); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedU64Correct)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + } + + // ============================================================================ + // Output shape tests — verify the stub produces correctly typed struct columns + // ============================================================================ 
+ + @Test + void testEmptySchemaProducesEmptyStruct() { + Byte[] row = new Byte[]{0x08, 0x01}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), + makeScalarSchema(new int[]{}, new int[]{}, new int[]{}), true)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + assertEquals(1, result.getRowCount()); + assertEquals(0, result.getNumChildren()); + } + } + + @Test + void testVarintZero() { + // Zero encoded as single byte: 0x00 + Byte[] row = concat(box(tag(1, WT_VARINT)), new Byte[]{0x00}); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(0L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testSingleScalarFieldOutputShape() { + Byte[] row = new Byte[]{0x08, 0x01}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), + makeScalarSchema( + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}), true)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + assertEquals(1, result.getRowCount()); + assertEquals(1, result.getNumChildren()); + assertEquals(DType.INT64, result.getChildColumnView(0).getType()); + } + } + + @Test + void testVarintOverEncodedZero() { + // Zero over-encoded as 10 bytes (all continuation bits except last) + // This is valid per protobuf spec - parsers must accept non-canonical varints + Byte[] row = concat( + box(tag(1, WT_VARINT)), + new Byte[]{(byte)0x80, (byte)0x80, (byte)0x80, 
(byte)0x80, (byte)0x80, + (byte)0x80, (byte)0x80, (byte)0x80, (byte)0x80, (byte)0x00}); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(0L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testVarint10thByteInvalid() { + // 10th byte with more than 1 significant bit is invalid + // (uint64 can only hold 64 bits: 9*7=63 bits + 1 bit from 10th byte) + Byte[] row = concat( + box(tag(1, WT_VARINT)), + new Byte[]{(byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0x02}); // 0x02 has 2nd bit set + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + false)) { + try (ColumnVector expected = ColumnVector.fromBoxedLongs((Long)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); + } + } + } + + // ============================================================================ + // ZigZag Boundary Tests + // ============================================================================ + + @Test + void testZigzagInt32Min() { + // int32 min = -2147483648 + // zigzag encoded = 4294967295 = 0xFFFFFFFF + int minInt32 = Integer.MIN_VALUE; + Byte[] row = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(zigzagEncode32(minInt32)))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = 
ColumnVector.fromBoxedInts(minInt32); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ZIGZAG})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testZigzagInt32Max() { + // int32 max = 2147483647 + // zigzag encoded = 4294967294 = 0xFFFFFFFE + int maxInt32 = Integer.MAX_VALUE; + Byte[] row = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(zigzagEncode32(maxInt32)))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(maxInt32); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ZIGZAG})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testZigzagInt64Min() { + // int64 min = -9223372036854775808 + long minInt64 = Long.MIN_VALUE; + Byte[] row = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(zigzagEncode64(minInt64)))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedLong = ColumnVector.fromBoxedLongs(minInt64); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedLong); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ZIGZAG})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testZigzagInt64Max() { + long maxInt64 = Long.MAX_VALUE; + Byte[] row = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(zigzagEncode64(maxInt64)))); + + try (Table input = new Table.TestBuilder().column(new 
Byte[][]{row}).build(); + ColumnVector expectedLong = ColumnVector.fromBoxedLongs(maxInt64); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedLong); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ZIGZAG})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testZigzagNegativeOne() { + // -1 zigzag encoded = 1 + Byte[] row = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(zigzagEncode64(-1L)))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedLong = ColumnVector.fromBoxedLongs(-1L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedLong); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ZIGZAG})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Truncated/Malformed Data Tests + // ============================================================================ + + @Test + void testMalformedVarint() { + // Varint that never terminates (all continuation bits set, 11 bytes) + Byte[] malformed = new Byte[]{(byte)0x08, (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{malformed}).build(); + ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + false)) { + assertSingleNullStructRow(result, "Malformed varint should null the struct row"); + } + } + + @Test + void testTruncatedVarint() { + // Single byte with continuation bit set but no following byte + Byte[] truncated 
= concat(box(tag(1, WT_VARINT)), new Byte[]{(byte)0x80}); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build(); + ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + false)) { + assertSingleNullStructRow(result, "Truncated varint should null the struct row"); + } + } + + @Test + void testTruncatedLengthDelimited() { + // String field with length=5 but no actual data + Byte[] truncated = concat(box(tag(2, WT_LEN)), box(encodeVarint(5))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build(); + ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{2}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{0}, + false)) { + assertSingleNullStructRow(result, + "Truncated length-delimited field should null the struct row"); + } + } + + @Test + void testTruncatedFixed32() { + // Fixed32 needs 4 bytes but only 3 provided + Byte[] truncated = concat(box(tag(1, WT_32BIT)), new Byte[]{0x01, 0x02, 0x03}); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build(); + ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_FIXED}, + false)) { + assertSingleNullStructRow(result, "Truncated fixed32 should null the struct row"); + } + } + + @Test + void testTruncatedFixed64() { + // Fixed64 needs 8 bytes but only 7 provided + Byte[] truncated = concat(box(tag(1, WT_64BIT)), + new Byte[]{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07}); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build(); + ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_FIXED}, + false)) { + assertSingleNullStructRow(result, "Truncated fixed64 should null the struct row"); + } + 
} + + @Test + void testPartialLengthDelimitedData() { + // Length says 10 bytes but only 5 provided + Byte[] partial = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(10)), + box("hello".getBytes())); // only 5 bytes + try (Table input = new Table.TestBuilder().column(new Byte[][]{partial}).build(); + ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{0}, + false)) { + assertSingleNullStructRow(result, + "Partial length-delimited payload should null the struct row"); + } + } + + // ============================================================================ + // Wrong Wire Type Tests + // ============================================================================ + + @Test + void testWrongWireType() { + // Expect varint (wire type 0) but provide fixed32 (wire type 5) + Byte[] wrongType = concat( + box(tag(1, WT_32BIT)), // wire type 5 instead of 0 + box(encodeFixed32(100))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{wrongType}).build(); + ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, // expects varint + new int[]{Protobuf.ENC_DEFAULT}, + false)) { + assertSingleNullStructRow(result, "Wrong wire type should null the struct row"); + } + } + + @Test + void testWrongWireTypeForString() { + // Expect length-delimited (wire type 2) but provide varint (wire type 0) + Byte[] wrongType = concat( + box(tag(1, WT_VARINT)), + box(encodeVarint(12345))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{wrongType}).build(); + ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, // expects LEN + new int[]{Protobuf.ENC_DEFAULT}, + false)) { + assertSingleNullStructRow(result, "Wrong wire type for string should null the struct row"); + } + } + + // 
============================================================================ + // Unknown Field Skip Tests + // ============================================================================ + + @Test + void testSkipUnknownVarintField() { + // Unknown field 99 with varint, followed by known field 1 + Byte[] row = concat( + box(tag(99, WT_VARINT)), + box(encodeVarint(12345)), // unknown field to skip + box(tag(1, WT_VARINT)), + box(encodeVarint(42))); // known field + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(42L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testSkipUnknownFixed64Field() { + // Unknown field 99 with fixed64, followed by known field 1 + Byte[] row = concat( + box(tag(99, WT_64BIT)), + box(encodeFixed64(0x123456789ABCDEF0L)), // unknown field to skip + box(tag(1, WT_VARINT)), + box(encodeVarint(42))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(42L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testSkipUnknownLengthDelimitedField() { + // Unknown field 99 with length-delimited data, followed by known field 1 + Byte[] row = concat( + box(tag(99, WT_LEN)), + box(encodeVarint(5)), + box("hello".getBytes()), // unknown field to skip + box(tag(1, WT_VARINT)), + box(encodeVarint(42))); + 
+ try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(42L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testSkipUnknownFixed32Field() { + // Unknown field 99 with fixed32, followed by known field 1 + Byte[] row = concat( + box(tag(99, WT_32BIT)), + box(encodeFixed32(12345)), // unknown field to skip + box(tag(1, WT_VARINT)), + box(encodeVarint(42))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(42L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Last One Wins (Repeated Scalar Field) Tests + // ============================================================================ + + @Test + void testLastOneWins() { + // Same field appears multiple times - last value should win + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(100)), + box(tag(1, WT_VARINT)), box(encodeVarint(200)), + box(tag(1, WT_VARINT)), box(encodeVarint(300))); // this should win + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs(300L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), 
+ new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testLastOneWinsForString() { + // Same string field appears multiple times + Byte[] row = concat( + box(tag(1, WT_LEN)), box(encodeVarint(5)), box("first".getBytes()), + box(tag(1, WT_LEN)), box(encodeVarint(6)), box("second".getBytes()), + box(tag(1, WT_LEN)), box(encodeVarint(4)), box("last".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedStr = ColumnVector.fromStrings("last"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedStr); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Error Handling Tests + // ============================================================================ + + @Test + void testFailOnErrorsTrue() { + Byte[] malformed = new Byte[]{(byte)0x08, (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{malformed}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + true)) { + } + }); + } + } + + @Test + void testFieldNumberZeroInvalid() { + // Field number 0 is reserved and invalid + Byte[] invalid = concat(box(tag(0, WT_VARINT)), box(encodeVarint(123))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{invalid}).build(); + ColumnVector result = 
decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + false)) { + assertSingleNullStructRow(result, "Field number zero should null the struct row"); + } + } + + @Test + void testEmptyMessage() { + // Empty message should result in null/default values for all fields + Byte[] empty = new Byte[0]; + try (Table input = new Table.TestBuilder().column(new Byte[][]{empty}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedLongs((Long)null); + ColumnVector expectedStr = ColumnVector.fromStrings((String)null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt, expectedStr); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{0, 0})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Float/Double Special Values Tests + // ============================================================================ + + @Test + void testFloatSpecialValues() { + Byte[] rowInf = concat(box(tag(1, WT_32BIT)), box(encodeFloat(Float.POSITIVE_INFINITY))); + Byte[] rowNegInf = concat(box(tag(1, WT_32BIT)), box(encodeFloat(Float.NEGATIVE_INFINITY))); + Byte[] rowNaN = concat(box(tag(1, WT_32BIT)), box(encodeFloat(Float.NaN))); + Byte[] rowMin = concat(box(tag(1, WT_32BIT)), box(encodeFloat(Float.MIN_VALUE))); + Byte[] rowMax = concat(box(tag(1, WT_32BIT)), box(encodeFloat(Float.MAX_VALUE))); + + try (Table input = new Table.TestBuilder().column(rowInf, rowNegInf, rowNaN, rowMin, rowMax).build(); + ColumnVector expectedFloat = ColumnVector.fromBoxedFloats( + Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY, Float.NaN, + Float.MIN_VALUE, Float.MAX_VALUE); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedFloat); + ColumnVector 
actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.FLOAT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDoubleSpecialValues() { + Byte[] rowInf = concat(box(tag(1, WT_64BIT)), box(encodeDouble(Double.POSITIVE_INFINITY))); + Byte[] rowNegInf = concat(box(tag(1, WT_64BIT)), box(encodeDouble(Double.NEGATIVE_INFINITY))); + Byte[] rowNaN = concat(box(tag(1, WT_64BIT)), box(encodeDouble(Double.NaN))); + Byte[] rowMin = concat(box(tag(1, WT_64BIT)), box(encodeDouble(Double.MIN_VALUE))); + Byte[] rowMax = concat(box(tag(1, WT_64BIT)), box(encodeDouble(Double.MAX_VALUE))); + + try (Table input = new Table.TestBuilder().column(rowInf, rowNegInf, rowNaN, rowMin, rowMax).build(); + ColumnVector expectedDouble = ColumnVector.fromBoxedDoubles( + Double.POSITIVE_INFINITY, Double.NEGATIVE_INFINITY, Double.NaN, + Double.MIN_VALUE, Double.MAX_VALUE); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedDouble); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.FLOAT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Enum Tests (enums.as.ints=true semantics) + // ============================================================================ + + @Test + void testEnumAsInt() { + // message Msg { enum Color { RED=0; GREEN=1; BLUE=2; } Color c = 1; } + // c = GREEN (value 1) - encoded as varint + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(1))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(1); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector 
actualStruct = decodeAllFields(
+        input.getColumn(0),
+        new int[]{1},
+        new int[]{DType.INT32.getTypeId().getNativeId()},
+        new int[]{Protobuf.ENC_DEFAULT})) {
+      AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct);
+    }
+  }
+
+  @Test
+  void testEnumZeroValue() {
+    // Enum with value 0 (first/default enum value)
+    // c = RED (value 0)
+    Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(0)));
+
+    try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build();
+        ColumnVector expectedInt = ColumnVector.fromBoxedInts(0);
+        ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt);
+        ColumnVector actualStruct = decodeAllFields(
+        input.getColumn(0),
+        new int[]{1},
+        new int[]{DType.INT32.getTypeId().getNativeId()},
+        new int[]{Protobuf.ENC_DEFAULT})) {
+      AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct);
+    }
+  }
+
+  @Test
+  void testEnumUnknownValue() {
+    // Protobuf allows unknown enum values - they should still be decoded as integers
+    // c = 999 (unknown value not in enum definition)
+    Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(999)));
+
+    try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build();
+        ColumnVector expectedInt = ColumnVector.fromBoxedInts(999);
+        ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt);
+        ColumnVector actualStruct = decodeAllFields(
+        input.getColumn(0),
+        new int[]{1},
+        new int[]{DType.INT32.getTypeId().getNativeId()},
+        new int[]{Protobuf.ENC_DEFAULT})) {
+      AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct);
+    }
+  }
+
+  @Test
+  void testEnumNegativeValue() {
+    // Negative enum values are valid in protobuf; canonical encoders sign-extend them to a 10-byte varint
+    // c = -1 is deliberately encoded non-canonically here as 0xFFFFFFFF to exercise int32 truncation on decode
+    Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(-1L & 0xFFFFFFFFL)));
+
+    try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build();
+        ColumnVector expectedInt = 
ColumnVector.fromBoxedInts(-1); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testEnumMultipleFields() { + // message Msg { enum Status { OK=0; ERROR=1; } Status s1 = 1; int32 count = 2; Status s2 = 3; } + // s1 = ERROR (1), count = 42, s2 = OK (0) + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), // s1 = ERROR + box(tag(2, WT_VARINT)), box(encodeVarint(42)), // count = 42 + box(tag(3, WT_VARINT)), box(encodeVarint(0))); // s2 = OK + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedS1 = ColumnVector.fromBoxedInts(1); + ColumnVector expectedCount = ColumnVector.fromBoxedInts(42); + ColumnVector expectedS2 = ColumnVector.fromBoxedInts(0); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedS1, expectedCount, expectedS2); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1, 2, 3}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testEnumMissingField() { + // Enum field not present in message - should be null + Byte[] row = concat(box(tag(2, WT_VARINT)), box(encodeVarint(42))); // only count field + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedEnum = ColumnVector.fromBoxedInts((Integer) null); + ColumnVector expectedCount = ColumnVector.fromBoxedInts(42); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedEnum, expectedCount); + 
ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT})) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Required Field Tests + // ============================================================================ + + @Test + void testRequiredFieldPresent() { + // message Msg { required int64 id = 1; optional string name = 2; } + // Both fields present - should decode successfully + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(42)), + box(tag(2, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedId = ColumnVector.fromBoxedLongs(42L); + ColumnVector expectedName = ColumnVector.fromStrings("hello"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedId, expectedName); + ColumnVector actualStruct = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, false}, // id is required, name is optional + true)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testRequiredFieldMissing_Permissive() { + // Required field missing in permissive mode - should null the whole row without exception + // message Msg { required int64 id = 1; optional string name = 2; } + // Only name field present, required id is missing + Byte[] row = concat( + box(tag(2, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector 
actualStruct = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, false}, // id is required, name is optional + false)) { // permissive mode - don't fail on errors + assertSingleNullStructRow(actualStruct, + "Missing top-level required field should null the row in PERMISSIVE mode"); + } + } + + @Test + void testRequiredFieldMissing_Failfast() { + // Required field missing in failfast mode - should throw exception + // message Msg { required int64 id = 1; optional string name = 2; } + // Only name field present, required id is missing + Byte[] row = concat( + box(tag(2, WT_LEN)), box(encodeVarint(5)), box("hello".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT64.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, false}, // id is required, name is optional + true)) { // failfast mode - should throw + } + }); + } + } + + @Test + void testMultipleRequiredFields_AllPresent() { + // message Msg { required int32 a = 1; required int64 b = 2; required string c = 3; } + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(10)), + box(tag(2, WT_VARINT)), box(encodeVarint(20)), + box(tag(3, WT_LEN)), box(encodeVarint(3)), box("abc".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedA = ColumnVector.fromBoxedInts(10); + ColumnVector expectedB = ColumnVector.fromBoxedLongs(20L); + ColumnVector expectedC = ColumnVector.fromStrings("abc"); + ColumnVector expectedStruct = 
ColumnVector.makeStruct(expectedA, expectedB, expectedC); + ColumnVector actualStruct = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1, 2, 3}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.INT64.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, true, true}, // all required + true)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testMultipleRequiredFields_SomeMissing_Failfast() { + // message Msg { required int32 a = 1; required int64 b = 2; required string c = 3; } + // Only field a is present, b and c are missing + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(10))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1, 2, 3}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.INT64.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, true, true}, // all required + true)) { + } + }); + } + } + + @Test + void testOptionalFieldsOnly_NoValidation() { + // All fields optional - missing fields should not cause error + // message Msg { optional int32 a = 1; optional int64 b = 2; } + Byte[] row = new Byte[0]; // empty message + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedA = ColumnVector.fromBoxedInts((Integer) null); + ColumnVector expectedB = ColumnVector.fromBoxedLongs((Long) null); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedA, expectedB); + ColumnVector actualStruct = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1, 2}, + new 
int[]{DType.INT32.getTypeId().getNativeId(), DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, // all optional + true)) { // even with failOnErrors=true, should succeed since all fields are optional + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testRequiredFieldWithMultipleRows() { + // Test required field validation across multiple rows + // Row 0: required field present + // Row 1: required field missing (should cause error in failfast mode) + Byte[] row0 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(42))); + Byte[] row1 = new Byte[0]; // empty - required field missing + + try (Table input = new Table.TestBuilder().column(row0, row1).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, // required + true)) { + } + }); + } + } + + @Test + void testRequiredFieldIgnoresNullInputRow_Failfast() { + Byte[] row0 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(42))); + Byte[] row1 = null; + + try (Table input = new Table.TestBuilder().column(row0, row1).build(); + ColumnVector actualStruct = decodeAllFieldsWithRequired( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + true); + ColumnVector idCol = actualStruct.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostStruct = actualStruct.copyToHost(); + HostColumnVector hostId = idCol.copyToHost()) { + assertEquals(0, actualStruct.getNullCount(), "Null input rows keep the top-level struct row"); + assertFalse(hostStruct.isNull(0), "Present required field should keep row 0 valid"); + assertFalse(hostStruct.isNull(1), "Null input row should not trigger 
required-field failure"); + assertEquals(1, idCol.getNullCount(), "The required child value should be null on the null input row"); + assertTrue(hostId.isNull(1), "Null input row should produce a null child value, not ERR_REQUIRED"); + } + } + + @Test + void testRequiredNestedMessageMissing_Failfast() { + // message Outer { required Inner detail = 1; } + // message Inner { optional int32 id = 1; } + // Missing top-level required nested message should fail in FAILFAST mode. + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector ignored = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, + new boolean[]{true, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + true)) { + } + }); + } + } + + @Test + void testRequiredFieldInsideNestedMessageMissing_Failfast() { + // message Outer { optional Inner detail = 1; } + // message Inner { required int32 id = 1; optional string name = 2; } + // If detail is present but nested required id is missing, FAILFAST should throw. 
+ Byte[] inner = concat( + box(tag(2, WT_LEN)), box(encodeVarint(4)), box("oops".getBytes())); + Byte[] row = concat( + box(tag(1, WT_LEN)), box(encodeVarint(inner.length)), inner); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector ignored = decodeRaw( + input.getColumn(0), + new int[]{1, 1, 2}, + new int[]{-1, 0, 0}, + new int[]{0, 1, 1}, + new int[]{WT_LEN, WT_VARINT, WT_LEN}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false}, + new boolean[]{false, true, false}, + new boolean[]{false, false, false}, + new long[]{0, 0, 0}, + new double[]{0.0, 0.0, 0.0}, + new boolean[]{false, false, false}, + new byte[][]{null, null, null}, + new int[][]{null, null, null}, + true)) { + } + }); + } + } + + @Test + void testRequiredFieldInsideNestedMessageMissing_Permissive() { + // message Outer { optional Inner detail = 1; optional string name = 2; } + // message Inner { required int32 id = 1; optional string note = 2; } + // If detail is present but nested required id is missing, PERMISSIVE should null the row. 
+ Byte[] inner = concat( + box(tag(2, WT_LEN)), box(encodeVarint(4)), box("oops".getBytes())); + Byte[] row = concat( + box(tag(1, WT_LEN)), box(encodeVarint(inner.length)), inner, + box(tag(2, WT_LEN)), box(encodeVarint(7)), box("outside".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 1, 2, 2}, + new int[]{-1, 0, 0, -1}, + new int[]{0, 1, 1, 0}, + new int[]{WT_LEN, WT_VARINT, WT_LEN, WT_LEN}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false, false}, + new boolean[]{false, true, false, false}, + new boolean[]{false, false, false, false}, + new long[]{0, 0, 0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false}, + new byte[][]{null, null, null, null}, + new int[][]{null, null, null, null}, + false)) { + assertSingleNullStructRow(actual, + "Missing nested required field should null the outer row in PERMISSIVE mode"); + } + } + + @Test + void testAbsentNestedParentSkipsRequiredChildCheck_Failfast() { + // message Outer { optional Inner detail = 1; } + // message Inner { required int32 id = 1; } + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, + new boolean[]{false, true}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, 
false}, + new byte[][]{null, null}, + new int[][]{null, null}, + true); + ColumnVector detailCol = actual.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostStruct = actual.copyToHost(); + HostColumnVector hostDetail = detailCol.copyToHost()) { + assertEquals(0, actual.getNullCount(), "Outer row should remain valid"); + assertFalse(hostStruct.isNull(0), "Top-level row should not be null"); + assertEquals(1, detailCol.getNullCount(), "Absent nested parent should stay null"); + assertTrue(hostDetail.isNull(0), "Missing optional nested struct should skip required-child error"); + } + } + + // ============================================================================ + // Default Value Tests (API accepts parameters, CUDA fill not yet implemented) + // ============================================================================ + + /** + * Helper method for tests with default value support. + */ + private static ColumnVector decodeAllFieldsWithDefaults(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] encodings, + boolean[] isRequired, + boolean[] hasDefaultValue, + long[] defaultInts, + double[] defaultFloats, + boolean[] defaultBools, + byte[][] defaultStrings, + boolean failOnErrors) { + int numFields = fieldNumbers.length; + return decodeScalarFields(binaryInput, fieldNumbers, typeIds, encodings, + isRequired, hasDefaultValue, defaultInts, defaultFloats, defaultBools, + defaultStrings, new int[numFields][], failOnErrors); + } + + @Test + void testDefaultValueForMissingFields() { + // Test that missing fields with default values return the defaults + Byte[] row = new Byte[0]; // empty message + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + // With default values set, missing fields should return the default values + ColumnVector expectedA = ColumnVector.fromBoxedInts(42); + ColumnVector expectedB = ColumnVector.fromBoxedLongs(100L); + ColumnVector expectedStruct = 
ColumnVector.makeStruct(expectedA, expectedB); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, // not required + new boolean[]{true, true}, // has default value + new long[]{42, 100}, // default int values (42, 100) + new double[]{0.0, 0.0}, // default float values (unused for int fields) + new boolean[]{false, false}, // default bool values (unused) + new byte[][]{null, null}, // default string values (unused) + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultValueFieldPresent_OverridesDefault() { + // When field is present, use the actual value (not the default) + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(99)), + box(tag(2, WT_VARINT)), box(encodeVarint(200))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedA = ColumnVector.fromBoxedInts(99); + ColumnVector expectedB = ColumnVector.fromBoxedLongs(200L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedA, expectedB); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, // not required + new boolean[]{true, true}, // has default value + new long[]{42, 100}, // default values - NOT used since field is present + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + false)) { + // Actual values should be used, not defaults + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultIntValue() { + // optional int32 count = 1 
[default = 42]; + // Empty message should return the default value + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(42); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, // not required + new boolean[]{true}, // has default + new long[]{42}, // default = 42 + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultBoolValue() { + // optional bool flag = 1 [default = true]; + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedBool = ColumnVector.fromBoxedBooleans(true); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedBool); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new int[]{DType.BOOL8.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{true}, // default = true + new byte[][]{null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultFloatValue() { + // optional double rate = 1 [default = 3.14]; + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedDouble = ColumnVector.fromBoxedDoubles(3.14); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedDouble); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new 
int[]{DType.FLOAT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{3.14}, // default = 3.14 + new boolean[]{false}, + new byte[][]{null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultInt64Value() { + // optional int64 big_num = 1 [default = 9876543210]; + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedLong = ColumnVector.fromBoxedLongs(9876543210L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedLong); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{9876543210L}, // default = 9876543210 + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testMixedDefaultAndNonDefaultFields() { + // optional int32 a = 1 [default = 42]; + // optional int64 b = 2; (no default) + // optional bool c = 3 [default = true]; + // Empty message: a=42, b=null, c=true + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedA = ColumnVector.fromBoxedInts(42); + ColumnVector expectedB = ColumnVector.fromBoxedLongs((Long) null); // no default + ColumnVector expectedC = ColumnVector.fromBoxedBooleans(true); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedA, expectedB, expectedC); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1, 2, 3}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.INT64.getTypeId().getNativeId(), + DType.BOOL8.getTypeId().getNativeId()}, + new 
int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false}, // not required + new boolean[]{true, false, true}, // a and c have defaults, b doesn't + new long[]{42, 0, 0}, // default for a + new double[]{0.0, 0.0, 0.0}, + new boolean[]{false, false, true}, // default for c + new byte[][]{null, null, null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultValueWithPartialMessage() { + // optional int32 a = 1 [default = 42]; + // optional int64 b = 2 [default = 100]; + // Message has only field b set, a should use default + Byte[] row = concat( + box(tag(2, WT_VARINT)), box(encodeVarint(999))); // b = 999 + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedA = ColumnVector.fromBoxedInts(42); // default + ColumnVector expectedB = ColumnVector.fromBoxedLongs(999L); // actual value + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedA, expectedB); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, // not required + new boolean[]{true, true}, // both have defaults + new long[]{42, 100}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultStringValue() { + // optional string name = 1 [default = "hello"]; + // Empty message should return the default string + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedStr = ColumnVector.fromStrings("hello"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedStr); + ColumnVector 
actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, // not required + new boolean[]{true}, // has default + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{"hello".getBytes(java.nio.charset.StandardCharsets.UTF_8)}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultStringValueEmpty() { + // optional string name = 1 [default = ""]; + // Empty message with empty default string + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedStr = ColumnVector.fromStrings(""); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedStr); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{new byte[0]}, // empty default string + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultStringValueWithPresent() { + // optional string name = 1 [default = "default"]; + // Message has actual value, should override default + byte[] strBytesRaw = "actual".getBytes(java.nio.charset.StandardCharsets.UTF_8); + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(strBytesRaw.length)), + box(strBytesRaw)); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedStr = ColumnVector.fromStrings("actual"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedStr); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new 
int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{"default".getBytes(java.nio.charset.StandardCharsets.UTF_8)}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultStringWithMixedFields() { + // optional int32 count = 1 [default = 42]; + // optional string name = 2 [default = "test"]; + // Empty message should return both defaults + Byte[] row = new Byte[0]; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(42); + ColumnVector expectedStr = ColumnVector.fromStrings("test"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedInt, expectedStr); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, + new boolean[]{true, true}, + new long[]{42, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, "test".getBytes(java.nio.charset.StandardCharsets.UTF_8)}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testDefaultStringMultipleRows() { + // optional string name = 1 [default = "default"]; + // Multiple rows: empty, has value, empty + Byte[] row1 = new Byte[0]; // will use default + byte[] strBytesRaw = "row2val".getBytes(java.nio.charset.StandardCharsets.UTF_8); + Byte[] row2 = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(strBytesRaw.length)), + box(strBytesRaw)); + Byte[] row3 = new Byte[0]; // will use default + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row1, row2, row3}).build(); + ColumnVector expectedStr = 
ColumnVector.fromStrings("default", "row2val", "default"); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedStr); + ColumnVector actualStruct = decodeAllFieldsWithDefaults( + input.getColumn(0), + new int[]{1}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{true}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{"default".getBytes(java.nio.charset.StandardCharsets.UTF_8)}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + // ============================================================================ + // Tests for Nested and Repeated Fields (Phase 1-3 Implementation) + // ============================================================================ + + @Test + void testUnpackedRepeatedInt32() { + // Unpacked repeated: same field number appears multiple times + // message TestMsg { repeated int32 ids = 1; } + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), + box(tag(1, WT_VARINT)), box(encodeVarint(2)), + box(tag(1, WT_VARINT)), box(encodeVarint(3))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + // Use the new nested API for repeated fields + // Field: ids (field_number=1, parent=-1, depth=0, wire_type=VARINT, type=INT32, repeated=true) + try (ColumnVector result = decodeRaw( + input.getColumn(0), + new int[]{1}, // fieldNumbers + new int[]{-1}, // parentIndices (-1 = top level) + new int[]{0}, // depthLevels + new int[]{Protobuf.WT_VARINT}, // wireTypes + new int[]{DType.INT32.getTypeId().getNativeId()}, // outputTypeIds (element type) + new int[]{Protobuf.ENC_DEFAULT}, // encodings + new boolean[]{true}, // isRepeated + new boolean[]{false}, // isRequired + new boolean[]{false}, // hasDefaultValue + new long[]{0}, // defaultInts + new double[]{0.0}, // defaultFloats + new boolean[]{false}, // defaultBools + new byte[][]{null}, // 
defaultStrings + new int[][]{null}, // enumValidValues + false)) { // failOnErrors + // Result should be STRUCT> + // The list should contain [1, 2, 3] + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + } + } + } + + @Test + void testPackedRepeatedDoubleWithMultipleFields() { + // Test packed repeated fields with multiple types including edge cases. + // message WithPackedRepeated { + // optional int32 id = 1; + // repeated int32 int_values = 2 [packed=true]; + // repeated double double_values = 3 [packed=true]; + // repeated bool bool_values = 4 [packed=true]; + // } + + // Helper to build packed int data (varints) + java.io.ByteArrayOutputStream intBuf = new java.io.ByteArrayOutputStream(); + + // Row 0: id=42, int_values=[1,-1,100] (12 bytes packed), double_values=[1.5,2.5], bool=[true,false] + // Row 1: id=7, int_values=15x(-1) (150 bytes packed, 2-byte length varint!), double_values=[3.0,4.0], bool=[true] + // Row 2: id=0, int_values=[] (field omitted), double_values=[5.0], bool=[] (field omitted) + + // --- Row 0 --- + byte[] r0IntVarints = concatBytes(encodeVarint(1), encodeVarint(-1L & 0xFFFFFFFFFFFFFFFFL), encodeVarint(100)); + byte[] r0Doubles = concatBytes(encodeDouble(1.5), encodeDouble(2.5)); + byte[] r0Bools = new byte[]{0x01, 0x00}; + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(42)), + box(tag(2, WT_LEN)), box(encodeVarint(r0IntVarints.length)), box(r0IntVarints), + box(tag(3, WT_LEN)), box(encodeVarint(r0Doubles.length)), box(r0Doubles), + box(tag(4, WT_LEN)), box(encodeVarint(r0Bools.length)), box(r0Bools)); + + // --- Row 1: 15 negative ints => 150 bytes packed (length varint is 2 bytes: 0x96 0x01) --- + java.io.ByteArrayOutputStream buf1 = new java.io.ByteArrayOutputStream(); + byte[] negOneVarint = encodeVarint(-1L & 0xFFFFFFFFFFFFFFFFL); // 10 bytes + for (int i = 0; i < 15; i++) { + buf1.write(negOneVarint, 0, negOneVarint.length); + } + byte[] r1IntVarints = buf1.toByteArray(); // 150 bytes + 
byte[] r1Doubles = concatBytes(encodeDouble(3.0), encodeDouble(4.0)); + byte[] r1Bools = new byte[]{0x01}; + Byte[] row1 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(7)), + box(tag(2, WT_LEN)), box(encodeVarint(r1IntVarints.length)), box(r1IntVarints), + box(tag(3, WT_LEN)), box(encodeVarint(r1Doubles.length)), box(r1Doubles), + box(tag(4, WT_LEN)), box(encodeVarint(r1Bools.length)), box(r1Bools)); + + // --- Row 2: no int_values, no bool_values --- + byte[] r2Doubles = encodeDouble(5.0); + Byte[] row2 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(0)), + box(tag(3, WT_LEN)), box(encodeVarint(r2Doubles.length)), box(r2Doubles)); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row0, row1, row2}).build()) { + try (ColumnVector result = decodeRaw( + input.getColumn(0), + new int[]{1, 2, 3, 4}, + new int[]{-1, -1, -1, -1}, + new int[]{0, 0, 0, 0}, + new int[]{WT_VARINT, WT_VARINT, WT_64BIT, WT_VARINT}, + new int[]{ + DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.FLOAT64.getTypeId().getNativeId(), + DType.BOOL8.getTypeId().getNativeId() + }, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, true, true, true}, + new boolean[]{false, false, false, false}, + new boolean[]{false, false, false, false}, + new long[]{0, 0, 0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false}, + new byte[][]{null, null, null, null}, + new int[][]{null, null, null, null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + assertEquals(3, result.getRowCount()); + + // Verify double_values child column has correct total count: 2 + 2 + 1 = 5 + try (ColumnVector doubleListCol = result.getChildColumnView(2).copyToColumnVector()) { + assertEquals(DType.LIST, doubleListCol.getType()); + try (ColumnVector doubleChildren = doubleListCol.getChildColumnView(0).copyToColumnVector()) { + 
assertEquals(DType.FLOAT64, doubleChildren.getType()); + assertEquals(5, doubleChildren.getRowCount(), + "Total packed doubles across 3 rows should be 5, got " + doubleChildren.getRowCount()); + try (HostColumnVector hd = doubleChildren.copyToHost()) { + assertEquals(1.5, hd.getDouble(0), 1e-10); + assertEquals(2.5, hd.getDouble(1), 1e-10); + assertEquals(3.0, hd.getDouble(2), 1e-10); + assertEquals(4.0, hd.getDouble(3), 1e-10); + assertEquals(5.0, hd.getDouble(4), 1e-10); + } + } + } + } + } + } + + /** Helper: concatenate byte arrays */ + private static byte[] concatBytes(byte[]... arrays) { + int len = 0; + for (byte[] a : arrays) len += a.length; + byte[] out = new byte[len]; + int pos = 0; + for (byte[] a : arrays) { + System.arraycopy(a, 0, out, pos, a.length); + pos += a.length; + } + return out; + } + + @Test + void testNestedMessage() { + // message Inner { int32 x = 1; } + // message Outer { Inner inner = 1; } + // Outer with inner.x = 42 + Byte[] innerMessage = concat(box(tag(1, WT_VARINT)), box(encodeVarint(42))); + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(innerMessage.length)), + innerMessage); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + // Flattened schema: + // [0] inner: STRUCT, field_number=1, parent=-1, depth=0 + // [1] inner.x: INT32, field_number=1, parent=0, depth=1 + try (ColumnVector result = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, // fieldNumbers + new int[]{-1, 0}, // parentIndices + new int[]{0, 1}, // depthLevels + new int[]{Protobuf.WT_LEN, Protobuf.WT_VARINT}, // wireTypes + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, // isRepeated + new boolean[]{false, false}, // isRequired + new boolean[]{false, false}, // hasDefaultValue + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + 
new int[][]{null, null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + } + } + } + + @Test + void testDeepNestedMessageDepth3() { + // message Inner { int32 a = 1; string b = 2; bool c = 3; } + // message Middle { Inner inner = 1; int64 m = 2; } + // message Outer { Middle middle = 1; float score = 2; } + Byte[] innerMessage = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(7)), + box(tag(2, WT_LEN)), box(encodeVarint(3)), box("abc".getBytes()), + box(tag(3, WT_VARINT)), new Byte[]{0x01}); + Byte[] middleMessage = concat( + box(tag(1, WT_LEN)), box(encodeVarint(innerMessage.length)), innerMessage, + box(tag(2, WT_VARINT)), box(encodeVarint(123L))); + Byte[] row = concat( + box(tag(1, WT_LEN)), box(encodeVarint(middleMessage.length)), middleMessage, + box(tag(2, WT_32BIT)), box(encodeFloat(1.25f))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedA = ColumnVector.fromBoxedInts(7); + ColumnVector expectedB = ColumnVector.fromStrings("abc"); + ColumnVector expectedC = ColumnVector.fromBoxedBooleans(true); + ColumnVector expectedInner = ColumnVector.makeStruct(expectedA, expectedB, expectedC); + ColumnVector expectedM = ColumnVector.fromBoxedLongs(123L); + ColumnVector expectedMiddle = ColumnVector.makeStruct(expectedInner, expectedM); + ColumnVector expectedScore = ColumnVector.fromBoxedFloats(1.25f); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedMiddle, expectedScore); + ColumnVector actualStruct = decodeRaw( + input.getColumn(0), + new int[]{1, 1, 1, 2, 3, 2, 2}, // fieldNumbers + new int[]{-1, 0, 1, 1, 1, 0, -1}, // parentIndices + new int[]{0, 1, 2, 2, 2, 1, 0}, // depthLevels + new int[]{Protobuf.WT_LEN, Protobuf.WT_LEN, Protobuf.WT_VARINT, Protobuf.WT_LEN, + Protobuf.WT_VARINT, Protobuf.WT_VARINT, Protobuf.WT_32BIT}, // wireTypes + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.STRUCT.getTypeId().getNativeId(), + 
DType.INT32.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId(), + DType.BOOL8.getTypeId().getNativeId(), DType.INT64.getTypeId().getNativeId(), + DType.FLOAT32.getTypeId().getNativeId()}, // outputTypeIds + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT}, // encodings + new boolean[]{false, false, false, false, false, false, false}, // isRepeated + new boolean[]{false, false, false, false, false, false, false}, // isRequired + new boolean[]{false, false, false, false, false, false, false}, // hasDefaultValue + new long[]{0, 0, 0, 0, 0, 0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false, false, false, false}, + new byte[][]{null, null, null, null, null, null, null}, + new int[][]{null, null, null, null, null, null, null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testPackedRepeatedInsideNestedMessage() { + // message Inner { repeated int32 ids = 1 [packed=true]; } + // message Outer { Inner inner = 1; } + byte[] packedIds = concatBytes(encodeVarint(10), encodeVarint(20), encodeVarint(30)); + Byte[] inner = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(packedIds.length)), + box(packedIds)); + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(inner.length)), + inner); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, // outer.inner, inner.ids + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, true}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + 
new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + false)) { + assertEquals(DType.STRUCT, result.getType()); + try (ColumnVector innerStruct = result.getChildColumnView(0).copyToColumnVector(); + ColumnVector idsList = innerStruct.getChildColumnView(0).copyToColumnVector(); + ColumnVector ids = idsList.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostIds = ids.copyToHost()) { + assertEquals(3, ids.getRowCount()); + assertEquals(10, hostIds.getInt(0)); + assertEquals(20, hostIds.getInt(1)); + assertEquals(30, hostIds.getInt(2)); + } + } + } + + @Test + void testPackedRepeatedChildInsideRepeatedMessage() { + // message Item { repeated int32 ids = 1 [packed=true]; optional int32 score = 2; } + // message Outer { repeated Item items = 1; } + byte[] item0Ids = concatBytes(encodeVarint(10), encodeVarint(20)); + Byte[] item0 = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(item0Ids.length)), + box(item0Ids), + box(tag(2, WT_VARINT)), + box(encodeVarint(7))); + byte[] item1Ids = concatBytes(encodeVarint(30)); + Byte[] item1 = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(item1Ids.length)), + box(item1Ids), + box(tag(2, WT_VARINT)), + box(encodeVarint(9))); + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(item0.length)), + item0, + box(tag(1, WT_LEN)), + box(encodeVarint(item1.length)), + item1); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedItems = ColumnVector.fromLists( + new ListType(true, + new StructType(true, + new ListType(true, new BasicType(true, DType.INT32)), + new BasicType(true, DType.INT32))), + Arrays.asList( + new StructData(Arrays.asList(10, 20), 7), + new StructData(Arrays.asList(30), 9))); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedItems); + ColumnVector actualStruct = decodeRaw( + input.getColumn(0), + new int[]{1, 1, 2}, + new int[]{-1, 0, 0}, + new int[]{0, 1, 1}, + new int[]{WT_LEN, WT_VARINT, 
WT_VARINT}, + new int[]{ + DType.STRUCT.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId() + }, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, true, false}, + new boolean[]{false, false, false}, + new boolean[]{false, false, false}, + new long[]{0, 0, 0}, + new double[]{0.0, 0.0, 0.0}, + new boolean[]{false, false, false}, + new byte[][]{null, null, null}, + new int[][]{null, null, null}, + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testPermissiveRepeatedWrongWireTypeDoesNotCorruptFollowingRow() { + // message Msg { repeated int32 ids = 1; } + // Row 0 has one valid element, then a malformed fixed32 occurrence for the same field, + // then another valid varint that must be ignored once the row is marked malformed. + // Row 1 must keep its own slot and not be overwritten by row 0's trailing occurrence. + Byte[] row0 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), + box(tag(1, WT_32BIT)), box(encodeFixed32(77)), + box(tag(1, WT_VARINT)), box(encodeVarint(2))); + Byte[] row1 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(100))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row0, row1}).build(); + ColumnVector expectedIds = ColumnVector.fromLists( + new ListType(true, new BasicType(true, DType.INT32)), + Arrays.asList(1), + Arrays.asList(100)); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedIds); + ColumnVector actualStruct = decodeRaw( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + 
AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testRepeatedUint32() { + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), + box(tag(1, WT_VARINT)), box(encodeVarint(2)), + box(tag(1, WT_VARINT)), box(encodeVarint(3))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = decodeRaw( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.UINT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + try (ColumnVector list = result.getChildColumnView(0).copyToColumnVector(); + ColumnVector vals = list.getChildColumnView(0).copyToColumnVector()) { + assertEquals(DType.UINT32, vals.getType()); + assertEquals(3, vals.getRowCount()); + } + } + } + + @Test + void testRepeatedUint64() { + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(11)), + box(tag(1, WT_VARINT)), box(encodeVarint(22)), + box(tag(1, WT_VARINT)), box(encodeVarint(33))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = decodeRaw( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.UINT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + try (ColumnVector list = result.getChildColumnView(0).copyToColumnVector(); + ColumnVector vals = list.getChildColumnView(0).copyToColumnVector()) { + assertEquals(DType.UINT64, vals.getType()); + assertEquals(3, vals.getRowCount()); + } + } + } + + @Test + void 
testWireTypeMismatchInRepeatedMessageChildFailfast() { + // message Item { int32 x = 1; } message Outer { repeated Item items = 1; } + // Encode x with WT_64BIT instead of WT_VARINT to force hard mismatch. + Byte[] badItem = concat(box(tag(1, WT_64BIT)), box(encodeFixed64(123L))); + Byte[] row = concat(box(tag(1, WT_LEN)), box(encodeVarint(badItem.length)), badItem); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector ignored = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, false}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + true)) { + } + }); + } + } + + // ============================================================================ + // FAILFAST Mode Tests (failOnErrors = true) + // ============================================================================ + + @Test + void testFailfastMalformedVarint() { + // Varint that never terminates (all continuation bits set) + Byte[] malformed = new Byte[]{(byte)0x08, (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, + (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF}; + try (Table input = new Table.TestBuilder().column(new Byte[][]{malformed}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + true)) { // failOnErrors = true + } + }); + } + } + + @Test + void testFailfastTruncatedVarint() { + // Single byte with continuation bit 
set but no following byte + Byte[] truncated = concat(box(tag(1, WT_VARINT)), new Byte[]{(byte)0x80}); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + true)) { + } + }); + } + } + + @Test + void testFailfastTruncatedString() { + // String field with length=5 but no actual data + Byte[] truncated = concat(box(tag(2, WT_LEN)), box(encodeVarint(5))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{2}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{0}, + true)) { + } + }); + } + } + + @Test + void testFailfastTruncatedFixed32() { + // Fixed32 needs 4 bytes but only 3 provided + Byte[] truncated = concat(box(tag(1, WT_32BIT)), new Byte[]{0x01, 0x02, 0x03}); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_FIXED}, + true)) { + } + }); + } + } + + @Test + void testFailfastTruncatedFixed64() { + // Fixed64 needs 8 bytes but only 5 provided + Byte[] truncated = concat(box(tag(1, WT_64BIT)), new Byte[]{0x01, 0x02, 0x03, 0x04, 0x05}); + try (Table input = new Table.TestBuilder().column(new Byte[][]{truncated}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_FIXED}, + true)) { + } + }); + } + } + + 
@Test + void testFailfastWrongWireType() { + // Field 1 with wire type 2 (length-delimited), but we request varint + Byte[] row = concat(box(tag(1, WT_LEN)), box(encodeVarint(3)), box("abc".getBytes())); + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + true)) { + } + }); + } + } + + @Test + void testFailfastFieldNumberZero() { + // Field number 0 is invalid in protobuf + Byte[] row = concat(box(tag(0, WT_VARINT)), box(encodeVarint(42))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + true)) { + } + }); + } + } + + @Test + void testFailfastFieldNumberAboveSpecLimit() { + // Protobuf field numbers must be <= 2^29 - 1. 
+ Byte[] row = concat(box(tag(1 << 29, WT_VARINT)), box(encodeVarint(42))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(ai.rapids.cudf.CudfException.class, () -> { + try (ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + true)) { + } + }); + } + } + + @Test + void testUnknownEndGroupWireTypeNullsMalformedRow() { + Byte[] row = concat( + box(tag(5, 4)), + box(tag(1, WT_VARINT)), box(encodeVarint(42))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + false)) { + assertSingleNullStructRow(actual, "Unknown end-group wire type should null the struct row"); + } + } -/** - * Tests for the Protobuf GPU decoder — framework PR. - * - * These tests verify the decode stub: schema validation, correct output shape, - * null column construction, and empty-row handling. Actual data extraction tests - * are added in follow-up PRs. 
- */ -public class ProtobufTest { + @Test + void testFailfastValidDataDoesNotThrow() { + // Valid protobuf should not throw even with failOnErrors = true + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(42))); + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = decodeAllFields( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{0}, + true)) { + try (ColumnVector expected = ColumnVector.fromBoxedLongs(42L); + ColumnVector expectedStruct = ColumnVector.makeStruct(expected)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, result); + } + } + } - private static ProtobufSchemaDescriptor makeScalarSchema(int[] fieldNumbers, int[] typeIds, - int[] encodings) { - int n = fieldNumbers.length; - int[] parentIndices = new int[n]; - int[] depthLevels = new int[n]; - int[] wireTypes = new int[n]; - boolean[] isRepeated = new boolean[n]; - boolean[] isRequired = new boolean[n]; - boolean[] hasDefault = new boolean[n]; - long[] defaultInts = new long[n]; - double[] defaultFloats = new double[n]; - boolean[] defaultBools = new boolean[n]; - byte[][] defaultStrings = new byte[n][]; - int[][] enumValid = new int[n][]; - byte[][][] enumNames = new byte[n][][]; + // ============================================================================ + // Performance Benchmark Tests (Multi-field) + // ============================================================================ + + @Test + void testMultiFieldPerformance() { + // Test with 6 fields to verify fused kernel efficiency + // message Msg { bool f1=1; int32 f2=2; int64 f3=3; float f4=4; double f5=5; string f6=6; } + Byte[] row = concat( + box(tag(1, WT_VARINT)), new Byte[]{0x01}, + box(tag(2, WT_VARINT)), box(encodeVarint(12345)), + box(tag(3, WT_VARINT)), box(encodeVarint(9876543210L)), + box(tag(4, WT_32BIT)), box(encodeFloat(3.14f)), + box(tag(5, WT_64BIT)), box(encodeDouble(2.71828)), + box(tag(6, 
WT_LEN)), box(encodeVarint(5)), box("hello".getBytes())); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actualStruct = decodeAllFields( + input.getColumn(0), + new int[]{1, 2, 3, 4, 5, 6}, + new int[]{ + DType.BOOL8.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.INT64.getTypeId().getNativeId(), + DType.FLOAT32.getTypeId().getNativeId(), + DType.FLOAT64.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId()}, + new int[]{0, 0, 0, 0, 0, 0})) { + try (ColumnVector expectedBool = ColumnVector.fromBoxedBooleans(true); + ColumnVector expectedInt = ColumnVector.fromBoxedInts(12345); + ColumnVector expectedLong = ColumnVector.fromBoxedLongs(9876543210L); + ColumnVector expectedFloat = ColumnVector.fromBoxedFloats(3.14f); + ColumnVector expectedDouble = ColumnVector.fromBoxedDoubles(2.71828); + ColumnVector expectedString = ColumnVector.fromStrings("hello"); + ColumnVector expectedStruct = ColumnVector.makeStruct( + expectedBool, expectedInt, expectedLong, expectedFloat, expectedDouble, expectedString)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + } + + // ============================================================================ + // Enum Validation Tests + // ============================================================================ + + /** + * Helper method that wraps decodeToStruct with enum validation support. 
+ */ + private static ColumnVector decodeAllFieldsWithEnums(ColumnView binaryInput, + int[] fieldNumbers, + int[] typeIds, + int[] encodings, + int[][] enumValidValues, + boolean failOnErrors) { + int numFields = fieldNumbers.length; + return decodeScalarFields(binaryInput, fieldNumbers, typeIds, encodings, + new boolean[numFields], new boolean[numFields], new long[numFields], + new double[numFields], new boolean[numFields], new byte[numFields][], + enumValidValues, failOnErrors); + } + /** + * Helper that enables enum-as-string decoding by passing enum name mappings. + */ + private static ColumnVector decodeAllFieldsWithEnumStrings(ColumnView binaryInput, + int[] fieldNumbers, + int[][] enumValidValues, + byte[][][] enumNames, + boolean failOnErrors) { + int numFields = fieldNumbers.length; + int[] typeIds = new int[numFields]; + int[] encodings = new int[numFields]; + for (int i = 0; i < numFields; i++) { + typeIds[i] = DType.STRING.getTypeId().getNativeId(); + encodings[i] = Protobuf.ENC_ENUM_STRING; + } + int[] parentIndices = new int[numFields]; + int[] depthLevels = new int[numFields]; + int[] wireTypes = new int[numFields]; + boolean[] isRepeated = new boolean[numFields]; java.util.Arrays.fill(parentIndices, -1); - for (int i = 0; i < n; i++) { - wireTypes[i] = deriveWireType(typeIds[i], encodings[i]); + java.util.Arrays.fill(wireTypes, Protobuf.WT_VARINT); + return Protobuf.decodeToStruct(binaryInput, + new ProtobufSchemaDescriptor(fieldNumbers, parentIndices, depthLevels, + wireTypes, typeIds, encodings, isRepeated, + new boolean[numFields], new boolean[numFields], new long[numFields], + new double[numFields], new boolean[numFields], new byte[numFields][], + enumValidValues, enumNames), + failOnErrors); + } + + @Test + void testEnumAsStringValidValue() { + // enum Color { RED=0; GREEN=1; BLUE=2; } + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(1))); // GREEN + + byte[][][] enumNames = new byte[][][] { + new byte[][] { + 
"RED".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "GREEN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BLUE".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedField = ColumnVector.fromStrings("GREEN"); + ColumnVector expected = ColumnVector.makeStruct(expectedField); + ColumnVector actual = decodeAllFieldsWithEnumStrings( + input.getColumn(0), + new int[]{1}, + new int[][]{{0, 1, 2}}, + enumNames, + false)) { + AssertUtils.assertStructColumnsAreEqual(expected, actual); } - return new ProtobufSchemaDescriptor(fieldNumbers, parentIndices, depthLevels, - wireTypes, typeIds, encodings, isRepeated, isRequired, hasDefault, - defaultInts, defaultFloats, defaultBools, defaultStrings, enumValid, enumNames); } - private static int deriveWireType(int typeId, int encoding) { - if (encoding == Protobuf.ENC_ENUM_STRING) return Protobuf.WT_VARINT; - if (typeId == DType.FLOAT32.getTypeId().getNativeId()) return Protobuf.WT_32BIT; - if (typeId == DType.FLOAT64.getTypeId().getNativeId()) return Protobuf.WT_64BIT; - if (typeId == DType.STRING.getTypeId().getNativeId()) return Protobuf.WT_LEN; - if (typeId == DType.LIST.getTypeId().getNativeId()) return Protobuf.WT_LEN; - if (typeId == DType.STRUCT.getTypeId().getNativeId()) return Protobuf.WT_LEN; - if (encoding == Protobuf.ENC_FIXED) { - if (typeId == DType.INT64.getTypeId().getNativeId()) return Protobuf.WT_64BIT; - return Protobuf.WT_32BIT; + @Test + void testEnumAsStringUnknownValueReturnsNullRow() { + // Unknown enum value should null the entire struct row (PERMISSIVE behavior). 
+ Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(999))); + + byte[][][] enumNames = new byte[][][] { + new byte[][] { + "RED".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "GREEN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BLUE".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeAllFieldsWithEnumStrings( + input.getColumn(0), + new int[]{1}, + new int[][]{{0, 1, 2}}, + enumNames, + false); + HostColumnVector hostStruct = actual.copyToHost()) { + assertEquals(1, actual.getNullCount(), "Struct row should be null for unknown enum value"); + assertTrue(hostStruct.isNull(0), "Row 0 should be null"); + } + } + + @Test + void testEnumAsStringMixedValidAndUnknown() { + Byte[] row0 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(0))); // RED + Byte[] row1 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(999))); // unknown + Byte[] row2 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(2))); // BLUE + + byte[][][] enumNames = new byte[][][] { + new byte[][] { + "RED".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "GREEN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BLUE".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + try (Table input = new Table.TestBuilder().column(row0, row1, row2).build(); + ColumnVector actual = decodeAllFieldsWithEnumStrings( + input.getColumn(0), + new int[]{1}, + new int[][]{{0, 1, 2}}, + enumNames, + false); + HostColumnVector hostStruct = actual.copyToHost()) { + assertEquals(1, actual.getNullCount(), "Only the unknown enum row should be null"); + assertTrue(!hostStruct.isNull(0), "Row 0 should be valid"); + assertTrue(hostStruct.isNull(1), "Row 1 should be null"); + assertTrue(!hostStruct.isNull(2), "Row 2 should be valid"); + } + } + + @Test + void testEnumValidValue() { + // enum Color { RED=0; GREEN=1; BLUE=2; } + // message Msg { Color color = 1; } + // Test with 
valid enum value (GREEN = 1) + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(1))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedColor = ColumnVector.fromBoxedInts(1); // GREEN + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedColor); + ColumnVector actualStruct = decodeAllFieldsWithEnums( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new int[][]{{0, 1, 2}}, // valid enum values: RED=0, GREEN=1, BLUE=2 + false)) { + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); + } + } + + @Test + void testEnumUnknownValueReturnsNullRow() { + // enum Color { RED=0; GREEN=1; BLUE=2; } + // message Msg { Color color = 1; } + // Test with unknown enum value (999 is not defined) + // The entire struct row should be null (matching Spark CPU PERMISSIVE mode) + Byte[] row = concat(box(tag(1, WT_VARINT)), box(encodeVarint(999))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actualStruct = decodeAllFieldsWithEnums( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new int[][]{{0, 1, 2}}, // valid enum values: RED=0, GREEN=1, BLUE=2 + false); + HostColumnVector hostStruct = actualStruct.copyToHost()) { + // The struct itself should be null (not just the field) + assertEquals(1, actualStruct.getNullCount(), "Struct row should be null for unknown enum"); + assertTrue(hostStruct.isNull(0), "Row 0 should be null"); + } + } + + @Test + void testEnumMixedValidAndUnknown() { + // Test multiple rows with mix of valid and unknown enum values + // Rows with unknown enum values should have null struct (not just null field) + Byte[] row0 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(0))); // RED (valid) -> struct valid + Byte[] row1 = concat(box(tag(1, WT_VARINT)), 
box(encodeVarint(999))); // unknown -> struct null + Byte[] row2 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(2))); // BLUE (valid) -> struct valid + Byte[] row3 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(-1))); // negative (unknown) -> struct null + + try (Table input = new Table.TestBuilder().column(row0, row1, row2, row3).build(); + ColumnVector actualStruct = decodeAllFieldsWithEnums( + input.getColumn(0), + new int[]{1}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new int[][]{{0, 1, 2}}, // valid enum values + false); + HostColumnVector hostStruct = actualStruct.copyToHost()) { + // Check struct-level nulls + assertEquals(2, actualStruct.getNullCount(), "Should have 2 null rows (rows 1 and 3)"); + assertTrue(!hostStruct.isNull(0), "Row 0 should be valid"); + assertTrue(hostStruct.isNull(1), "Row 1 should be null (unknown enum 999)"); + assertTrue(!hostStruct.isNull(2), "Row 2 should be valid"); + assertTrue(hostStruct.isNull(3), "Row 3 should be null (unknown enum -1)"); + } + } + + @Test + void testEnumWithOtherFields_NullsEntireRow() { + // message Msg { Color color = 1; int32 count = 2; } + // Test that unknown enum value nulls the ENTIRE struct row (not just the enum field) + // This matches Spark CPU PERMISSIVE mode behavior + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(999)), // unknown enum value + box(tag(2, WT_VARINT)), box(encodeVarint(42))); // count = 42 + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actualStruct = decodeAllFieldsWithEnums( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new int[][]{{0, 1, 2}, null}, // first field is enum, second is regular int (no validation) + false); + HostColumnVector hostStruct = actualStruct.copyToHost()) { + // The entire struct row should be 
null + assertEquals(1, actualStruct.getNullCount(), "Struct row should be null"); + assertTrue(hostStruct.isNull(0), "Row 0 should be null due to unknown enum"); + } + } + + @Test + void testRepeatedStructEnumInvalidKeepsTopLevelRowValid() { + // enum Color { RED=0; GREEN=1; BLUE=2; } + // message Item { Color color = 1; } + // message Msg { repeated Item items = 1; } + Byte[] item00 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(0))); // valid + Byte[] item01 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(999))); // invalid + Byte[] row0 = concat( + box(tag(1, WT_LEN)), box(encodeVarint(item00.length)), item00, + box(tag(1, WT_LEN)), box(encodeVarint(item01.length)), item01); + Byte[] item10 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(1))); // valid + Byte[] row1 = concat( + box(tag(1, WT_LEN)), box(encodeVarint(item10.length)), item10); + + try (Table input = new Table.TestBuilder().column(row0, row1).build(); + ColumnVector actualStruct = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, false}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, new int[]{0, 1, 2}}, + false); + ColumnVector items = actualStruct.getChildColumnView(0).copyToColumnVector(); + ColumnVector itemStructs = items.getChildColumnView(0).copyToColumnVector(); + ColumnVector colors = itemStructs.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostStruct = actualStruct.copyToHost(); + HostColumnVector hostColors = colors.copyToHost()) { + assertEquals(0, actualStruct.getNullCount(), "Invalid child enum should not null the top-level row"); + assertFalse(hostStruct.isNull(0), "Row 0 should 
remain valid"); + assertFalse(hostStruct.isNull(1), "Row 1 should remain valid"); + assertEquals(3, colors.getRowCount(), "All repeated message elements should remain visible"); + assertEquals(1, colors.getNullCount(), "Only the invalid enum field should be null"); + assertEquals(0, hostColors.getInt(0), "The first repeated child should keep its valid enum"); + assertTrue(hostColors.isNull(1), "The invalid repeated child enum should decode as null"); + assertEquals(1, hostColors.getInt(2), "The second row should keep its valid enum"); + } + } + + @Test + void testRepeatedStructEnumInvalidKeepsSiblingFieldsVisible() { + // enum Color { RED=0; GREEN=1; BLUE=2; } + // message Item { Color color = 1; int32 count = 2; } + // message Msg { repeated Item items = 1; } + Byte[] item00 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(0)), + box(tag(2, WT_VARINT)), box(encodeVarint(10))); + Byte[] item01 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(999)), + box(tag(2, WT_VARINT)), box(encodeVarint(20))); + Byte[] row0 = concat( + box(tag(1, WT_LEN)), box(encodeVarint(item00.length)), item00, + box(tag(1, WT_LEN)), box(encodeVarint(item01.length)), item01); + Byte[] item10 = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), + box(tag(2, WT_VARINT)), box(encodeVarint(30))); + Byte[] row1 = concat( + box(tag(1, WT_LEN)), box(encodeVarint(item10.length)), item10); + + try (Table input = new Table.TestBuilder().column(row0, row1).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 1, 2}, + new int[]{-1, 0, 0}, + new int[]{0, 1, 1}, + new int[]{WT_LEN, WT_VARINT, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{true, false, false}, + new boolean[]{false, false, false}, + new boolean[]{false, false, false}, + new long[]{0, 0, 0}, + new 
double[]{0.0, 0.0, 0.0}, + new boolean[]{false, false, false}, + new byte[][]{null, null, null}, + new int[][]{null, new int[]{0, 1, 2}, null}, + false); + ColumnView itemsView = actual.getChildColumnView(0); + ColumnView itemStructView = itemsView.getChildColumnView(0); + ColumnView colorView = itemStructView.getChildColumnView(0); + ColumnView countView = itemStructView.getChildColumnView(1); + ColumnVector colorVector = colorView.copyToColumnVector(); + ColumnVector countVector = countView.copyToColumnVector(); + HostColumnVector hostStruct = actual.copyToHost(); + HostColumnVector hostColors = colorVector.copyToHost(); + HostColumnVector hostCounts = countVector.copyToHost()) { + HostColumnVectorCore hostItems = hostStruct.getChildColumnView(0); + + assertEquals(0, actual.getNullCount(), "Invalid child enum should not null the parent row"); + assertFalse(hostStruct.isNull(0), "Row 0 should remain valid"); + assertFalse(hostStruct.isNull(1), "Row 1 should remain valid"); + + assertEquals(0, hostItems.getNullCount(), "LIST rows should remain valid"); + assertFalse(hostItems.isNull(0), "items[0] should remain valid"); + assertFalse(hostItems.isNull(1), "items[1] should remain valid"); + + assertEquals(3, itemStructView.getRowCount(), + "All repeated message elements should remain visible"); + assertEquals(0, itemStructView.getNullCount(), + "No repeated struct element should be dropped"); + assertEquals(1, colorView.getNullCount(), + "Only the invalid enum child should be null"); + assertEquals(0, hostColors.getInt(0), + "The first repeated child should keep its valid enum"); + assertTrue(hostColors.isNull(1), + "The invalid repeated child enum should decode as null"); + assertEquals(1, hostColors.getInt(2), + "The second row should keep its valid enum"); + assertEquals(3, countView.getRowCount(), + "Sibling fields should remain visible for every repeated element"); + assertEquals(0, countView.getNullCount(), + "Sibling scalar fields should stay non-null when only 
the enum is invalid");
+      assertEquals(10, hostCounts.getInt(0));
+      assertEquals(20, hostCounts.getInt(1));
+      assertEquals(30, hostCounts.getInt(2));
+    }
+  }
+
+  @Test
+  void testEnumMissingFieldDoesNotNullRow() {
+    // Missing enum field should return null for the field, but NOT null the entire row
+    // Only unknown enum values (present but invalid) trigger row-level null
+    Byte[] row = new Byte[0]; // empty message
+
+    try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build();
+         ColumnVector expectedColor = ColumnVector.fromBoxedInts((Integer) null);
+         ColumnVector expectedStruct = ColumnVector.makeStruct(expectedColor);
+         ColumnVector actualStruct = decodeAllFieldsWithEnums(
+             input.getColumn(0),
+             new int[]{1},
+             new int[]{DType.INT32.getTypeId().getNativeId()},
+             new int[]{Protobuf.ENC_DEFAULT},
+             new int[][]{{0, 1, 2}}, // valid enum values
+             false)) {
+      // Struct row should be valid (not null), only the field is null
+      assertEquals(0, actualStruct.getNullCount(), "Struct row should NOT be null for missing field");
+      AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct);
+    }
+  }
+
+  @Test
+  void testNestedEnumInvalidKeepsRowAndSiblingFieldsInPermissiveMode() {
+    // message WithNestedEnum {
+    //   optional int32 id = 1;
+    //   optional Detail detail = 2;
+    //   optional string name = 3;
+    // }
+    // message Detail {
+    //   enum Status { UNKNOWN = 0; OK = 1; BAD = 2; }
+    //   optional Status status = 1;
+    //   optional int32 count = 2;
+    // }
+    // Invalid nested enum nulls only the enum field itself; the row, its siblings, and the grandchild field detail.count remain valid.
+ Byte[] detail = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(999)), + box(tag(2, WT_VARINT)), box(encodeVarint(20))); + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(2)), + box(tag(2, WT_LEN)), box(encodeVarint(detail.length)), detail, + box(tag(3, WT_LEN)), box(encodeVarint(3)), box("bad".getBytes())); + + byte[][][] enumNames = new byte[][][] { + null, + null, + null, + new byte[][] { + "UNKNOWN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "OK".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BAD".getBytes(java.nio.charset.StandardCharsets.UTF_8) + }, + null + }; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 2, 3, 1, 2}, + new int[]{-1, -1, -1, 1, 1}, + new int[]{0, 0, 0, 1, 1}, + new int[]{WT_VARINT, WT_LEN, WT_LEN, WT_VARINT, WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.STRUCT.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_ENUM_STRING, + Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false, false, false}, + new boolean[]{false, false, false, false, false}, + new boolean[]{false, false, false, false, false}, + new long[]{0, 0, 0, 0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false, false}, + new byte[][]{null, null, null, null, null}, + new int[][]{null, null, null, new int[]{0, 1, 2}, null}, + enumNames, + false); + ColumnVector detailCol = actual.getChildColumnView(1).copyToColumnVector(); + ColumnVector statusCol = detailCol.getChildColumnView(0).copyToColumnVector(); + ColumnVector countCol = detailCol.getChildColumnView(1).copyToColumnVector(); + HostColumnVector hostStruct = actual.copyToHost(); + HostColumnVector hostDetail = 
detailCol.copyToHost(); + HostColumnVector hostStatus = statusCol.copyToHost(); + HostColumnVector hostCount = countCol.copyToHost()) { + assertEquals(0, actual.getNullCount(), "Invalid nested enum should not null the top-level row"); + assertFalse(hostStruct.isNull(0), "Top-level struct should remain valid"); + assertEquals(0, detailCol.getNullCount(), "Nested struct should remain present"); + assertFalse(hostDetail.isNull(0), "Nested struct row should remain valid"); + assertEquals(1, statusCol.getNullCount(), "Only the invalid enum field should be null"); + assertTrue(hostStatus.isNull(0), "detail.status should decode as null"); + assertEquals(0, countCol.getNullCount(), "Sibling nested fields should remain visible"); + assertFalse(hostCount.isNull(0), "detail.count should remain valid"); + assertEquals(20, hostCount.getInt(0), "detail.count should preserve the decoded value"); + } + } + + @Test + void testMalformedNestedEnumPermissiveNullsWholeRow() { + // message WithNestedEnum { + // optional int32 id = 1; + // optional Detail detail = 2; + // optional string name = 3; + // } + // message Detail { + // enum Status { UNKNOWN = 0; OK = 1; BAD = 2; } + // optional Status status = 1; + // optional int32 count = 2; + // } + // + // The nested message length is intentionally truncated to 4 bytes. Spark CPU treats this as a + // malformed row in PERMISSIVE mode and returns a null struct row rather than partial data. 
+ Byte[] rowValid = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), + box(tag(2, WT_LEN)), box(encodeVarint(4)), + box(tag(1, WT_VARINT)), box(encodeVarint(1)), + box(tag(2, WT_VARINT)), box(encodeVarint(10)), + box(tag(3, WT_LEN)), box(encodeVarint(2)), box("ok".getBytes())); + Byte[] rowInvalid = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(2)), + box(tag(2, WT_LEN)), box(encodeVarint(4)), + box(tag(1, WT_VARINT)), box(encodeVarint(999)), + box(tag(2, WT_VARINT)), box(encodeVarint(20)), + box(tag(3, WT_LEN)), box(encodeVarint(3)), box("bad".getBytes())); + + try (Table input = new Table.TestBuilder().column(rowValid, rowInvalid).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 2, 3, 1, 2}, + new int[]{-1, -1, -1, 1, 1}, + new int[]{0, 0, 0, 1, 1}, + new int[]{WT_VARINT, WT_LEN, WT_LEN, WT_VARINT, WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId(), + DType.STRUCT.getTypeId().getNativeId(), + DType.STRING.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId(), + DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT, + Protobuf.ENC_DEFAULT}, + new boolean[]{false, false, false, false, false}, + new boolean[]{false, false, false, false, false}, + new boolean[]{false, false, false, false, false}, + new long[]{0, 0, 0, 0, 0}, + new double[]{0.0, 0.0, 0.0, 0.0, 0.0}, + new boolean[]{false, false, false, false, false}, + new byte[][]{null, null, null, null, null}, + new int[][]{null, null, null, new int[]{0, 1, 2}, null}, + false); + HostColumnVector hostStruct = actual.copyToHost()) { + assertEquals(1, actual.getNullCount(), "Only the malformed row should be null"); + assertFalse(hostStruct.isNull(0), "The valid row should remain decoded"); + assertTrue(hostStruct.isNull(1), "The malformed nested row should be null in PERMISSIVE mode"); + } + } + + @Test + void testEnumValidWithOtherFields() { + // message Msg { 
Color color = 1; int32 count = 2; } + // Test that valid enum value works correctly with other fields + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(1)), // GREEN (valid) + box(tag(2, WT_VARINT)), box(encodeVarint(42))); // count = 42 + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector expectedColor = ColumnVector.fromBoxedInts(1); + ColumnVector expectedCount = ColumnVector.fromBoxedInts(42); + ColumnVector expectedStruct = ColumnVector.makeStruct(expectedColor, expectedCount); + ColumnVector actualStruct = decodeAllFieldsWithEnums( + input.getColumn(0), + new int[]{1, 2}, + new int[]{DType.INT32.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new int[][]{{0, 1, 2}, null}, // first field is enum, second is regular int + false)) { + // Struct row should be valid with correct values + assertEquals(0, actualStruct.getNullCount(), "Struct row should be valid"); + AssertUtils.assertStructColumnsAreEqual(expectedStruct, actualStruct); } - return Protobuf.WT_VARINT; } // ============================================================================ - // Output shape tests — verify the stub produces correctly typed struct columns + // Repeated Enum-as-String Tests // ============================================================================ @Test - void testEmptySchemaProducesEmptyStruct() { - Byte[] row = new Byte[]{0x08, 0x01}; + void testRepeatedEnumAsString() { + // repeated Color colors = 1; with Color { RED=0; GREEN=1; BLUE=2; } + // Row with three occurrences: RED, BLUE, GREEN + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(0)), // RED + box(tag(1, WT_VARINT)), box(encodeVarint(2)), // BLUE + box(tag(1, WT_VARINT)), box(encodeVarint(1))); // GREEN + + byte[][][] enumNames = new byte[][][] { + new byte[][] { + "RED".getBytes(java.nio.charset.StandardCharsets.UTF_8), + 
"GREEN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BLUE".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), - makeScalarSchema(new int[]{}, new int[]{}, new int[]{}), true)) { - assertNotNull(result); - assertEquals(DType.STRUCT, result.getType()); - assertEquals(1, result.getRowCount()); - assertEquals(0, result.getNumChildren()); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{Protobuf.WT_VARINT}, + new int[]{DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_ENUM_STRING}, + new boolean[]{true}, // isRepeated + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{{0, 1, 2}}, + enumNames, + false)) { + assertNotNull(actual); + assertEquals(DType.STRUCT, actual.getType()); + assertEquals(1, actual.getNumChildren()); + try (ColumnView listCol = actual.getChildColumnView(0)) { + assertEquals(DType.LIST, listCol.getType()); + try (ColumnView strChild = listCol.getChildColumnView(0); + HostColumnVector hostStrs = strChild.copyToHost()) { + assertEquals(3, hostStrs.getRowCount()); + assertEquals("RED", hostStrs.getJavaString(0)); + assertEquals("BLUE", hostStrs.getJavaString(1)); + assertEquals("GREEN", hostStrs.getJavaString(2)); + } + } } } @Test - void testSingleScalarFieldOutputShape() { - Byte[] row = new Byte[]{0x08, 0x01}; + void testRepeatedMessageChildEnumAsString() { + // message Item { optional Priority priority = 1; } + // message Outer { repeated Item items = 1; } + // enum Priority { UNKNOWN=0; FOO=1; BAR=2; } + Byte[] item0 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(1))); // FOO + Byte[] item1 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(2))); // BAR + Byte[] row = concat( + box(tag(1, WT_LEN)), 
box(encodeVarint(item0.length)), item0, + box(tag(1, WT_LEN)), box(encodeVarint(item1.length)), item1); + + byte[][][] enumNames = new byte[][][] { + null, + new byte[][] { + "UNKNOWN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "FOO".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BAR".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); - ColumnVector result = Protobuf.decodeToStruct(input.getColumn(0), - makeScalarSchema( - new int[]{1}, - new int[]{DType.INT64.getTypeId().getNativeId()}, - new int[]{Protobuf.ENC_DEFAULT}), true)) { + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_ENUM_STRING}, + new boolean[]{true, false}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, new int[]{0, 1, 2}}, + enumNames, + false); + ColumnVector items = actual.getChildColumnView(0).copyToColumnVector(); + ColumnVector itemStructs = items.getChildColumnView(0).copyToColumnVector(); + ColumnVector priorities = itemStructs.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostPriorities = priorities.copyToHost()) { + assertEquals(2, priorities.getRowCount()); + assertEquals("FOO", hostPriorities.getJavaString(0)); + assertEquals("BAR", hostPriorities.getJavaString(1)); + } + } + + @Test + void testRepeatedMessageChildEnumAsStringInvalidKeepsRowValid() { + Byte[] item0 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(1))); // FOO + Byte[] item1 = concat(box(tag(1, WT_VARINT)), box(encodeVarint(999))); // invalid + Byte[] row = concat( + box(tag(1, WT_LEN)), box(encodeVarint(item0.length)), item0, + 
box(tag(1, WT_LEN)), box(encodeVarint(item1.length)), item1); + + byte[][][] enumNames = new byte[][][] { + null, + new byte[][] { + "UNKNOWN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "FOO".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BAR".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_ENUM_STRING}, + new boolean[]{true, false}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, new int[]{0, 1, 2}}, + enumNames, + false); + ColumnVector items = actual.getChildColumnView(0).copyToColumnVector(); + ColumnVector itemStructs = items.getChildColumnView(0).copyToColumnVector(); + ColumnVector priorities = itemStructs.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostStruct = actual.copyToHost(); + HostColumnVector hostPriorities = priorities.copyToHost()) { + assertEquals(0, actual.getNullCount(), "Invalid child enum should not null the top-level row"); + assertFalse(hostStruct.isNull(0), "The top-level row should remain valid"); + assertEquals(2, priorities.getRowCount(), "Both repeated message elements should remain visible"); + assertEquals(1, priorities.getNullCount(), "Only the invalid enum field should be null"); + assertEquals("FOO", hostPriorities.getJavaString(0)); + assertTrue(hostPriorities.isNull(1), "The invalid repeated child enum should decode as null"); + } + } + + @Test + void testNestedRepeatedEnumAsString() { + // message Inner { repeated Priority priority = 1; } + // message Outer { optional Inner inner = 
1; } + // enum Priority { UNKNOWN=0; FOO=1; BAR=2; } + byte[] packedPriorities = concatBytes(encodeVarint(0), encodeVarint(2), encodeVarint(1)); + Byte[] inner = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(packedPriorities.length)), + box(packedPriorities)); + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(inner.length)), + inner); + + byte[][][] enumNames = new byte[][][] { + null, + new byte[][] { + "UNKNOWN".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "FOO".getBytes(java.nio.charset.StandardCharsets.UTF_8), + "BAR".getBytes(java.nio.charset.StandardCharsets.UTF_8) + } + }; + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector actual = decodeRaw( + input.getColumn(0), + new int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.STRING.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_ENUM_STRING}, + new boolean[]{false, true}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, new int[]{0, 1, 2}}, + enumNames, + false); + ColumnVector innerStruct = actual.getChildColumnView(0).copyToColumnVector(); + ColumnVector priorityList = innerStruct.getChildColumnView(0).copyToColumnVector(); + ColumnVector priorities = priorityList.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostPriorities = priorities.copyToHost()) { + assertEquals(3, priorities.getRowCount()); + assertEquals("UNKNOWN", hostPriorities.getJavaString(0)); + assertEquals("BAR", hostPriorities.getJavaString(1)); + assertEquals("FOO", hostPriorities.getJavaString(2)); + } + } + + // ============================================================================ + // Edge case and boundary tests + // ============================================================================ + 
+ @Test + void testPackedFixedMisaligned() { + byte[] packedData = new byte[]{0x01, 0x02, 0x03, 0x04, 0x05}; + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(packedData.length)), + box(packedData)); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(RuntimeException.class, () -> { + try (ColumnVector result = decodeRaw( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_32BIT}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_FIXED}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + true)) { + } + }); + } + } + + @Test + void testPackedFixedMisaligned64() { + byte[] packedData = new byte[]{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09}; + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(packedData.length)), + box(packedData)); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build()) { + assertThrows(RuntimeException.class, () -> { + try (ColumnVector result = decodeRaw( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_64BIT}, + new int[]{DType.INT64.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_FIXED}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + true)) { + } + }); + } + } + + @Test + void testLargeRepeatedField() throws Exception { + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + for (int i = 0; i < 100000; i++) { + baos.write(tag(1, WT_VARINT)); + baos.write(encodeVarint(i)); + } + Byte[] row = box(baos.toByteArray()); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = decodeRaw( + input.getColumn(0), + new int[]{1}, + new 
int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { assertNotNull(result); assertEquals(DType.STRUCT, result.getType()); - assertEquals(1, result.getRowCount()); - assertEquals(1, result.getNumChildren()); - assertEquals(DType.INT64, result.getChildColumnView(0).getType()); + try (ColumnVector list = result.getChildColumnView(0).copyToColumnVector()) { + assertEquals(DType.LIST, list.getType()); + } } } @@ -169,6 +3473,46 @@ void testNullInputRowProducesNullStructRow() { } } + @Test + void testMixedPackedUnpacked() { + byte[] packedContent = concatBytes(encodeVarint(30), encodeVarint(40)); + Byte[] row = concat( + box(tag(1, WT_VARINT)), box(encodeVarint(10)), + box(tag(1, WT_VARINT)), box(encodeVarint(20)), + box(tag(1, WT_LEN)), box(encodeVarint(packedContent.length)), box(packedContent)); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = decodeRaw( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + try (ColumnVector list = result.getChildColumnView(0).copyToColumnVector(); + ColumnVector vals = list.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostVals = vals.copyToHost()) { + assertEquals(4, vals.getRowCount()); + assertEquals(10, hostVals.getInt(0)); + assertEquals(20, hostVals.getInt(1)); + assertEquals(30, hostVals.getInt(2)); + 
assertEquals(40, hostVals.getInt(3)); + } + } + } + @Test void testAllNullInputRows() { try (Table input = new Table.TestBuilder().column(new Byte[][]{null, null, null}).build(); @@ -188,6 +3532,100 @@ void testAllNullInputRows() { } } + @Test + void testLargeFieldNumber() { + int maxFieldNumber = (1 << 29) - 1; + Byte[] row = concat( + box(tag(maxFieldNumber, WT_VARINT)), + box(encodeVarint(42))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = decodeRaw( + input.getColumn(0), + new int[]{maxFieldNumber}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{false}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + try (ColumnVector child = result.getChildColumnView(0).copyToColumnVector(); + HostColumnVector hostChild = child.copyToHost()) { + assertEquals(42, hostChild.getInt(0)); + } + } + } + + private void verifyDeepNesting(int numLevels) { + int numFields = 2 * numLevels - 1; + + byte[] current = concatBytes(tag(1, WT_VARINT), encodeVarint(1)); + for (int level = numLevels - 2; level >= 0; level--) { + current = concatBytes( + tag(1, WT_VARINT), encodeVarint(1), + tag(2, WT_LEN), encodeVarint(current.length), current); + } + Byte[] row = box(current); + + int[] fieldNumbers = new int[numFields]; + int[] parentIndices = new int[numFields]; + int[] depthLevels = new int[numFields]; + int[] wireTypes = new int[numFields]; + int[] outputTypeIds = new int[numFields]; + int[] encodings = new int[numFields]; + boolean[] isRepeated = new boolean[numFields]; + boolean[] isRequired = new boolean[numFields]; + boolean[] hasDefaultValue = new boolean[numFields]; + long[] defaultInts = new long[numFields]; + double[] 
defaultFloats = new double[numFields]; + boolean[] defaultBools = new boolean[numFields]; + byte[][] defaultStrings = new byte[numFields][]; + int[][] enumValidValues = new int[numFields][]; + + for (int level = 0; level < numLevels; level++) { + int intIdx = 2 * level; + int parentIdx = level == 0 ? -1 : 2 * (level - 1) + 1; + + fieldNumbers[intIdx] = 1; + parentIndices[intIdx] = parentIdx; + depthLevels[intIdx] = level; + wireTypes[intIdx] = WT_VARINT; + outputTypeIds[intIdx] = DType.INT32.getTypeId().getNativeId(); + encodings[intIdx] = Protobuf.ENC_DEFAULT; + + if (level < numLevels - 1) { + int structIdx = 2 * level + 1; + fieldNumbers[structIdx] = 2; + parentIndices[structIdx] = parentIdx; + depthLevels[structIdx] = level; + wireTypes[structIdx] = WT_LEN; + outputTypeIds[structIdx] = DType.STRUCT.getTypeId().getNativeId(); + encodings[structIdx] = Protobuf.ENC_DEFAULT; + } + } + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = decodeRaw( + input.getColumn(0), + fieldNumbers, parentIndices, depthLevels, wireTypes, + outputTypeIds, encodings, isRepeated, isRequired, + hasDefaultValue, defaultInts, defaultFloats, defaultBools, + defaultStrings, enumValidValues, false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + } + } + // ============================================================================ // Empty-row (0 rows) handling // ============================================================================ @@ -249,6 +3687,45 @@ void testNestedMessageOutputShape() { } } + @Test + void testDeepNesting9Levels() { + verifyDeepNesting(9); + } + + @Test + void testDeepNesting10Levels() { + verifyDeepNesting(10); + } + + @Test + void testZeroLengthNestedMessage() { + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(0))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = decodeRaw( + input.getColumn(0), + new 
int[]{1, 1}, + new int[]{-1, 0}, + new int[]{0, 1}, + new int[]{WT_LEN, WT_VARINT}, + new int[]{DType.STRUCT.getTypeId().getNativeId(), DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT, Protobuf.ENC_DEFAULT}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new boolean[]{false, false}, + new long[]{0, 0}, + new double[]{0.0, 0.0}, + new boolean[]{false, false}, + new byte[][]{null, null}, + new int[][]{null, null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + } + } + @Test void testRepeatedFieldOutputShape() { // Schema: message Msg { repeated int32 values = 1; } @@ -281,6 +3758,35 @@ void testRepeatedFieldOutputShape() { } } + @Test + void testEmptyPackedRepeated() { + Byte[] row = concat( + box(tag(1, WT_LEN)), + box(encodeVarint(0))); + + try (Table input = new Table.TestBuilder().column(new Byte[][]{row}).build(); + ColumnVector result = decodeRaw( + input.getColumn(0), + new int[]{1}, + new int[]{-1}, + new int[]{0}, + new int[]{WT_VARINT}, + new int[]{DType.INT32.getTypeId().getNativeId()}, + new int[]{Protobuf.ENC_DEFAULT}, + new boolean[]{true}, + new boolean[]{false}, + new boolean[]{false}, + new long[]{0}, + new double[]{0.0}, + new boolean[]{false}, + new byte[][]{null}, + new int[][]{null}, + false)) { + assertNotNull(result); + assertEquals(DType.STRUCT, result.getType()); + } + } + @Test void testZeroRowNestedSchemaShape() { // 0 rows with nested schema — verify correct type hierarchy