Add a protocol buffer decode kernel#4107
Conversation
There was a problem hiding this comment.
Pull request overview
This PR adds a GPU-accelerated protocol buffer decoder with intentionally limited features, focusing on simple scalar field types. The implementation provides a JNI interface for decoding binary protobuf messages into cuDF STRUCT columns.
Key changes:
- Implements GPU kernels for decoding protobuf varint, fixed32/64, and length-delimited (string) fields
- Adds JNI bindings between Java and CUDA implementation
- Provides basic test coverage for INT64 and STRING field types
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 14 comments.
Show a summary per file
| File | Description |
|---|---|
| src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java | Java API providing decodeToStruct() method with parameter validation |
| src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java | Basic test case covering varint (INT64) and string decoding with missing fields and null messages |
| src/main/cpp/src/protobuf_simple.hpp | C++ API declaration with documentation of supported types |
| src/main/cpp/src/protobuf_simple.cu | CUDA implementation with three specialized kernels for varint, fixed-width, and string extraction |
| src/main/cpp/src/ProtobufSimpleJni.cpp | JNI bridge translating Java arrays to C++ vectors and invoking decode logic |
| src/main/cpp/CMakeLists.txt | Build configuration adding new source files to compilation targets |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java
Outdated
Show resolved
Hide resolved
src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java
Outdated
Show resolved
Hide resolved
|
@greptile full review |
Greptile SummaryThis PR delivers a comprehensive GPU-accelerated protobuf decoder (~12,000 lines of new CUDA/C++/Java) that converts The PR has undergone extensive review iteration; all previously reported correctness issues have been addressed:
One P1 build failure remains: Two minor P2 notes: Confidence Score: 4/5Safe to merge for the core library and JNI/Java code; the benchmark build target will fail until All previously reported correctness bugs have been resolved — including the byte-granular memset for ERR_BOUNDS init, sfixed default value branch, repeated enum validation, varint error propagation, fast-path depth checks, and JNI local-ref leaks. The remaining P1 is limited to the benchmark file, which references a deleted header and stale namespaces and will fail to compile. Since the benchmark is explicitly flagged as 'to be submitted separately' in the PR description but is wired into CMakeLists.txt, it will break the benchmark build target. Fixing or removing the benchmark file brings confidence to 5/5. src/main/cpp/benchmarks/protobuf_decode.cu — includes Important Files Changed
Sequence DiagramsequenceDiagram
participant Java as Protobuf.java
participant JNI as ProtobufJni.cpp
participant Entry as protobuf.cu
participant Scan as protobuf_kernels.cu
participant Extract as protobuf_kernels.cuh
participant Build as protobuf_builders.cu
Java->>JNI: decodeToStruct(binaryInput, schema, failOnErrors)
JNI->>JNI: validate arrays, convert to host_vectors
JNI->>Entry: decode_protobuf_to_struct(input, context, stream, mr)
Entry->>Entry: validate_decode_context()
Entry->>Scan: count_repeated_fields_kernel (Pass 1: count + locate)
Scan-->>Entry: d_repeated_info, d_nested_locations
Entry->>Entry: thrust::exclusive_scan (Pass 2: prefix-sum offsets)
Entry->>Scan: scan_all_fields_kernel (Pass 3a: scalar field locations)
Entry->>Scan: scan_all_repeated_occurrences_kernel (Pass 3b: repeated locations)
Entry->>Extract: extract_varint/fixed/lengths_batched_kernel (Pass 4: parallel extraction)
Entry->>Build: build_nested_struct_column (Pass 5: recursive assembly)
Build->>Scan: scan_nested_message_fields_kernel
Build->>Scan: count/scan_repeated_in_nested_kernel
Build-->>Entry: nested cudf::column
Entry->>Entry: check_required_fields, validate_enum_values
Entry->>Entry: PERMISSIVE null propagation
Entry-->>JNI: std::unique_ptr cudf::column (STRUCT)
JNI-->>Java: ColumnVector handle
|
| Benchmark uses | Correct qualified name |
|---|---|
spark_rapids_jni::nested_field_descriptor |
spark_rapids_jni::protobuf::nested_field_descriptor |
spark_rapids_jni::protobuf_detail::field_location |
spark_rapids_jni::protobuf::detail::field_location |
spark_rapids_jni::protobuf_detail::repeated_occurrence |
spark_rapids_jni::protobuf::detail::repeated_occurrence |
spark_rapids_jni::wire_type_value(...) |
spark_rapids_jni::protobuf::wire_type_value(...) |
spark_rapids_jni::proto_wire_type::VARINT |
spark_rapids_jni::protobuf::proto_wire_type::VARINT |
spark_rapids_jni::protobuf_decode_context |
spark_rapids_jni::protobuf::protobuf_decode_context |
The protobuf_common.cuh include should be replaced with the actual split headers, and all spark_rapids_jni:: type references updated to include the ::protobuf:: or ::protobuf::detail:: component.
Since the PR description notes benchmarks will be submitted separately (appended to PR 4), consider either removing this file from the current PR or adding a #error guard so CI makes the mismatch visible rather than failing silently.
Reviews (54): Last reviewed commit: "apply refactor on headers" | Re-trigger Greptile
There was a problem hiding this comment.
Additional Comments (6)
-
src/main/cpp/src/protobuf_simple.cu, line 89-91 (link)logic: potential overflow:
len64can be up to 2^64-1, but casting tointon line 90 can overflow iflen64 > INT_MAX -
src/main/cpp/src/protobuf_simple.cu, line 323-324 (link)logic: potential overflow:
len64can be larger thanINT_MAX, but casting tointon line 324 will overflow -
src/main/cpp/src/protobuf_simple.cu, line 375-376 (link)logic: race condition: multiple threads write to
*error_flagwithout atomics, causing undefined behavior when multiple threads encounter errors simultaneouslyThen in kernels, use
atomicOr(error_flag, 1)instead of*error_flag = 1 -
src/main/cpp/src/protobuf_simple.cu, line 398-407 (link)logic: protobuf uses zigzag encoding for signed integers (sint32/sint64), but varint decoding here treats them as unsigned - decoding negative values will produce incorrect results. Are you only supporting unsigned int32/int64, or should zigzag decoding be implemented for signed types?
-
src/main/cpp/src/protobuf_simple.cu, line 240 (link)syntax: type punning through
reinterpret_castof incompatible pointer types is undefined behavior in C++ -
src/main/cpp/src/protobuf_simple.cu, line 248 (link)syntax: type punning through
reinterpret_castof incompatible pointer types is undefined behavior in C++
6 files reviewed, 6 comments
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
|
@greptile full review |
There was a problem hiding this comment.
Additional Comments (2)
-
src/main/cpp/src/protobuf_simple.cu, line 186 (link)logic: zigzag decode uses signed right shift which is correct for signed types, but applied to unsigned
v -
src/main/cpp/src/ProtobufSimpleJni.cpp, line 55 (link)logic:
encodingsis constructed fromn_type_scalesbut thenout_typesis also constructed usingn_type_scales[i]as the scale parameter, which would be wrong for non-decimal types where this represents encoding
6 files reviewed, 2 comments
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 13 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java
Outdated
Show resolved
Hide resolved
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 10 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java
Outdated
Show resolved
Hide resolved
src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java
Outdated
Show resolved
Hide resolved
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
…ven/spark-rapids-jni into protocol_buffer_jni_dev
|
@greptileai full review |
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
|
NOTE: release/26.02 has been created from main. Please retarget your PR to release/26.02 if it should be included in the release. |
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
|
NOTE: release/26.04 has been created from main. Please retarget your PR to release/26.04 if it should be included in the release. |
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
src/main/cpp/src/protobuf.hpp
Outdated
| namespace spark_rapids_jni { | ||
|
|
||
| // Encoding constants | ||
| constexpr int ENC_DEFAULT = 0; |
There was a problem hiding this comment.
Create an enum to improve readability like:
enum class type_id : int32_t
src/main/cpp/src/protobuf.hpp
Outdated
| constexpr int MAX_FIELD_NUMBER = (1 << 29) - 1; | ||
|
|
||
| // Wire type constants | ||
| constexpr int WT_VARINT = 0; |
| cudf::size_type num_rows, | ||
| rmm::cuda_stream_view stream, | ||
| rmm::device_async_resource_ref mr) | ||
| { |
There was a problem hiding this comment.
Add a fast checking path to avoid unnessary GPU tasks.
if (parent_null_count == 0) return;
| stream); | ||
| col.set_null_count(child_view.size() - valid_count); | ||
| } else { | ||
| auto child_mask = cudf::detail::copy_bitmask(parent_mask_ptr, 0, num_rows, stream, mr); |
There was a problem hiding this comment.
The if path considers view.offset, but else path does not, is this intended?
There was a problem hiding this comment.
Yes that's intended, added a CUDF_EXPECTS to ensure it.
words from ai:
apply_parent_mask_to_row_aligned_columntakes an owningcudf::column&, not a slicedcolumn_view. For an owning column,column::view()/mutable_view()constructs a view withoffset = 0, andchild()likewise returns the owning child column rather than a sliced child view.So in the current protobuf decoder path both branches are effectively operating on offset-0 columns.
The nullable branch still passes
child_view.offset()becauseinplace_bitmask_andis a generic view-oriented API; in this call site that value is always 0. I added theCUDF_EXPECTS(child_view.offset() == 0, ...)in the non-nullable branch to make that invariant explicit and to prevent future misuse if this helper is ever reused on non-owning/sliced inputs.`
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
|
How does the performance compare to the CPU? It is great that we have this all working. I agree that it is very large, but unless it is faster than running it on the CPU, then there is not much of a reason to do it. I am especially concerned about large protobuffers and the memory access pattern. A thread per entry is potentially very problematic, as we found for large JSON strings in other work that is similar to this. How does validation work? I didn't see anything in there about detecting malformed protobufs and returning an error. |
|
Thanks for the review @revans2
Performance is looking positive by far. We're seeing 3~5x speedup over CPU on a real customer schema (~200 field nested message with repeated structs, ~500 output columns). The customer has also confirmed cost savings in their end-to-end testing (using different schemas). I'll add more detailed benchmark numbers in the plugin pr NVIDIA/spark-rapids#14354 description soon.
Yes it could be problematic for large protobuffers, but the situation is generally better than it was with getJsonObject.
I will add some large protobuffers testing in performance testsing to see how bad it is.
Yes there are many validations in the code. We have 12 distinct error codes, which cover all Protobuf wire-format failure modes: switch (code) {
case ERR_BOUNDS: return "Protobuf decode error: message data out of bounds";
case ERR_VARINT: return "Protobuf decode error: invalid or truncated varint";
case ERR_FIELD_NUMBER: return "Protobuf decode error: invalid field number";
case ERR_WIRE_TYPE: return "Protobuf decode error: unexpected wire type";
case ERR_OVERFLOW: return "Protobuf decode error: length-delimited field overflows message";
case ERR_FIELD_SIZE: return "Protobuf decode error: invalid field size";
case ERR_SKIP: return "Protobuf decode error: unable to skip unknown field";
case ERR_FIXED_LEN:
return "Protobuf decode error: invalid fixed-width or packed field length";
case ERR_REQUIRED: return "Protobuf decode error: missing required field";
case ERR_SCHEMA_TOO_LARGE:
return "Protobuf decode error: schema exceeds maximum supported repeated fields per kernel "
"(128)";
case ERR_MISSING_ENUM_META:
return "Protobuf decode error: missing or mismatched enum metadata for enum-as-string "
"field";
case ERR_REPEATED_COUNT_MISMATCH:
return "Protobuf decode error: repeated-field count/scan mismatch";
default: return "Protobuf decode error: unknown error";And two error modes are supported, matching Spark's There are some related java unit tests: |
|
That is really great work. Thanks for the detailed analysis. |
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
This PR adds a protocol buffer decoder with a large subset of Proto2 features, to support spark expression
from_protobuf.The code is ready to me but too large to review. I'm splitting it into small parts:
Part 0: #4373
PR Split Plan
This PR is being split into a linear chain of four focused PRs to make review more manageable. Each PR is independently compilable and testable, and each subsequent PR is purely additive to the prior one (mostly inserting new code sections rather than modifying existing logic):
PR 1 — Framework + API + stub decode (~2,500 lines, ~16 tests): Establishes the full public API surface (
protobuf.hpp,Protobuf.java,ProtobufSchemaDescriptor.java), the JNI bridge (ProtobufJni.cpp), the shared infrastructure header (protobuf_common.cuh), and a stubdecode_protobuf_to_structthat validates the schema and returns a correctly-shaped STRUCT column with all-null children. The shared header is included in full so that no follow-up PR needs to modify it. Review focus: schema validation logic (both Java and C++ sides), JNI memory safety, API contract, and correct null-column type construction for nested/repeated schemas.PR 2 — Scalar type extraction (~3,700 lines, ~55 tests): Adds the core decode pipeline —
scan_all_fields_kernelfor single-pass field location recording, batched 2D extraction kernels grouped by type (varint, fixed32/64, zigzag), two-phase string/bytes construction, default value substitution, and required-field checks. Covers all scalar protobuf types (int32,int64,uint32,uint64,sint32/64,fixed32/64,sfixed32/64,float,double,bool,string,bytes). The change is a pure insertion intoprotobuf.cubetween the field classification and the assembly section — no existing code is modified.PR 3 — Repeated fields + nested messages (~4,400 lines, ~25 tests): Adds the three-phase repeated-field pipeline (count → prefix-sum offsets → scan occurrences → build LIST columns) and recursive nested-message decoding (up to 10 levels deep). Includes repeated-in-nested and repeated-in-repeated support. Again a pure insertion into the orchestrator, between the scalar section and the final assembly.
PR 4 — Enum-as-string + PERMISSIVE mode (~1,900 lines, ~40 tests): Adds enum value validation, varint-to-UTF8-name conversion with GPU-parallel lookup tables, and PERMISSIVE mode null propagation (invalid enum or malformed rows nullify the entire struct row and propagate nulls to all descendants). This is the only PR that modifies a small amount of existing code (adding
d_row_force_nulltracking to the setup section and null propagation after the assembly).Benchmarks (8 NVBench cases, ~1,400 lines) will be submitted separately or appended to PR 4.
Summary
This PR adds a GPU-accelerated protobuf decoder that converts
LIST<INT8/UINT8>columns (one serialized protobuf message per row) into nested cuDFSTRUCTcolumns via JNI. This is the native kernel layer that powersfrom_protobuf()GPU acceleration in the spark-rapids plugin.The implementation spans ~6,500 lines of new CUDA/C++/Java code and ~3,900 lines of tests, organized into a clean four-file C++ architecture plus a validated Java schema API.
Key capabilities
int32,int64,uint32,uint64,sint32/sint64(zigzag),fixed32/sfixed32/fixed64/sfixed64,float,double,bool,string,bytesArrayType(StructType)— repeated nested messages with arbitrary child fieldsPerformance characteristics
FIELD_LOOKUP_TABLE_MAX = 4096)Architecture
File structure
Dependency graph
Multi-pass decode algorithm
The decoder processes each batch of messages through multiple GPU passes:
count_repeated_fields_kernel): One thread per row. Scans message bytes to count repeated field occurrences and record nested message locations. Handles both packed and unpacked repeated encoding.thrust::exclusive_scan): Prefix sum on repeated counts to compute output array offsets.scan_all_fields_kernel+scan_all_repeated_occurrences_kernel): Records exact byte locations (offset + length) for every target field. Last-one-wins semantics for duplicate scalar fields.MAX_NESTED_STRUCT_DECODE_DEPTH = 10.Flattened schema representation
The protobuf schema is represented as parallel arrays passed through JNI. Fields are ordered in pre-order traversal (parent before children):
fieldNumbers[]parentIndices[]depthLevels[]wireTypes[]outputTypeIds[]encodings[]isRepeated[]isRequired[]hasDefaultValue[]defaultInts/Floats/Bools/Strings[]enumValidValues[][]enumNames[][][]Example for
message Outer { int32 a = 1; Inner b = 2; } message Inner { int32 x = 1; string y = 2; }:Test coverage
107 JUnit tests in
ProtobufTest.java+ 13 tests inProtobufSchemaDescriptorTest.java, organized by feature:Benchmarks
8 NVBench benchmarks in
protobuf_decode.cu:BM_protobuf_flat_scalarsBM_protobuf_nestedBM_protobuf_repeatedBM_protobuf_wide_repeated_messageBM_protobuf_repeated_child_listsBM_protobuf_repeated_child_string_count_scanBM_protobuf_repeated_child_string_buildBM_protobuf_many_repeatedReview Guide
This PR is large (~12,000 lines total) but has a well-defined layered architecture. This guide provides a recommended reading order, key areas to focus on per file, and a mental model for understanding the code.
Recommended reading order
Read bottom-up from the API surface to the kernel internals:
protobuf.hppProtobufSchemaDescriptor.javaProtobuf.javaProtobufJni.cppprotobuf_common.cuh§1: typesfield_location,device_nested_field_descriptor, etc.protobuf_common.cuh§2: device helpersread_varint,skip_field,get_field_data_location,decode_tag,lookup_fieldprotobuf_common.cuh§3: LocationProvidersTopLevelLocationProvider,NestedLocationProvider, etc. — these abstract how extraction kernels compute byte offsetsprotobuf_common.cuh§4: template kernelsextract_varint_kernel,extract_fixed_kernel,extract_lengths_kernel,copy_varlen_data_kernel, batched variantsprotobuf_common.cuh§5: template host functionsextract_typed_column,build_repeated_scalar_column,extract_and_build_string_or_bytes_column,validate_enum_and_propagate_rowsprotobuf_kernels.cu§1: scanscan_all_fields_kernel— the core single-pass field scannerprotobuf_kernels.cu§2: count/scan repeatedcount_repeated_fields_kernel,scan_all_repeated_occurrences_kernel, shared__device__helpersprotobuf_kernels.cu§3: nestedscan_nested_message_fields_kernel,scan_repeated_message_children_kernel, compute kernelsprotobuf_kernels.cu§4: validationcheck_required_fields_kernel,validate_enum_values_kernel, enum-string kernelsprotobuf_builders.cu§1: utilitiesmake_null_column,make_empty_column_safe,make_null_list_column_with_childprotobuf_builders.cu§2: enum-stringmake_enum_string_lookup_tables,build_enum_string_column,build_repeated_enum_string_columnprotobuf_builders.cu§3: nested structbuild_nested_struct_column— most complex builder, recursive depth handlingprotobuf_builders.cu§4: repeated structbuild_repeated_struct_column,build_repeated_child_list_column— repeated-in-repeatedprotobuf.cudecode_protobuf_to_struct: orchestration, batched scalar extraction, PERMISSIVE null propagationProtobufTest.javaprotobuf_decode.cuTotal estimated review time: ~5-6 hours for a thorough review.
Key review areas by priority
P0: Correctness-critical
scan_all_fields_kernel(protobuf_kernels.cu): This is the most important kernel. One thread per row, scans all bytes of each message, records field locations. Key things to verify:cur < msg_end)atomicCAS(no races)row_has_invalid_datais set on parse errors so the row can be nullifiedcount_repeated_fields_kernel(protobuf_kernels.cu): Counts repeated field occurrences. Must correctly distinguish packed vs unpacked encoding. Packed detection:wire_type == WT_LENbutexpected_wire_type != WT_LEN.build_nested_struct_column(protobuf_builders.cu): Recursive builder for nested messages. Verify:MAX_NESTED_STRUCT_DECODE_DEPTH)is_repeatedchildren inside nested messages get proper LIST wrappingPERMISSIVE mode null propagation (protobuf.cu, lines ~1163-1190): Invalid enum values must nullify the entire struct row AND propagate to all descendants via
apply_parent_mask_to_row_aligned_column+propagate_nulls_to_descendants.JNI memory safety (ProtobufJni.cpp): Every
GetObjectArrayElement/GetByteArrayElements/GetIntArrayElementsmust have a matchingDeleteLocalRef/ReleaseXxxArrayElements. Verify no leaks in the enum_names triple-nested loop.P1: Robustness
validate_decode_context(protobuf.hpp): Validates all schema invariants before any GPU work. Check completeness — duplicate field numbers under same parent, encoding compatibility, enum metadata non-empty forENC_ENUM_STRING.ProtobufSchemaDescriptor.validate(Java): Mirrors C++ validation. Defensive copies in constructor, re-validation on deserialization.Varint parsing (protobuf_common.cuh
read_varint): 10th byte must only use lowest bit. Truncated/malformed varints must returnfalse.Wire type handling (
skip_field,get_wire_type_size): VerifyWT_SGROUPuses iterative handling with depth cap of 32 (not recursive).WT_EGROUPis rejected as standalone.P2: Performance
Batched scalar extraction (protobuf.cu, lines ~400-470): Fields are grouped into 12 categories by type and extracted with 2D kernel launches. Verify grouping logic covers all type+encoding combinations.
Field lookup tables (
build_field_lookup_table,build_index_lookup_tablein protobuf_common.cuh): O(1) field_number → index mapping when max field number ≤FIELD_LOOKUP_TABLE_MAX. Falls back to linear scan otherwise.String two-phase construction:
extract_lengths_kernel→make_offsets_child_column→copy_varlen_data_kernel. Verify no off-by-one in offset calculations.Things to watch for
uvector.size()afteruvector.release()in the same expression. The code caches sizes before releasing.num_rows + 1elements.count_repeated_in_nested_kernelandscan_repeated_in_nested_kernelhandle both packed and unpacked within nested message boundaries.cudf::logic_errorfor data errors: The code usescudf::logic_errorfor wire-format errors in strict mode. This is semantically imprecise (it conventionally signals API misuse), but functionally correct.Mapping review to test coverage
scan_all_fields_kernelcount/scan_repeatedbuild_nested_struct_columnvalidate_enum + enum-stringcheck_required_fields_kernel