Skip to content

Add a protocol buffer decode kernel#4107

Open
thirtiseven wants to merge 105 commits intoNVIDIA:mainfrom
thirtiseven:protocol_buffer_jni_dev
Open

Add a protocol buffer decode kernel#4107
thirtiseven wants to merge 105 commits intoNVIDIA:mainfrom
thirtiseven:protocol_buffer_jni_dev

Conversation

@thirtiseven
Copy link
Copy Markdown
Collaborator

@thirtiseven thirtiseven commented Dec 23, 2025

This PR adds a protocol buffer decoder with a large subset of Proto2 features, to support spark expression from_protobuf.

The code is ready to me but too large to review. I'm splitting it into small parts:

Part 0: #4373

PR Split Plan

This PR is being split into a linear chain of four focused PRs to make review more manageable. Each PR is independently compilable and testable, and each subsequent PR is purely additive to the prior one (mostly inserting new code sections rather than modifying existing logic):

  1. PR 1 — Framework + API + stub decode (~2,500 lines, ~16 tests): Establishes the full public API surface (protobuf.hpp, Protobuf.java, ProtobufSchemaDescriptor.java), the JNI bridge (ProtobufJni.cpp), the shared infrastructure header (protobuf_common.cuh), and a stub decode_protobuf_to_struct that validates the schema and returns a correctly-shaped STRUCT column with all-null children. The shared header is included in full so that no follow-up PR needs to modify it. Review focus: schema validation logic (both Java and C++ sides), JNI memory safety, API contract, and correct null-column type construction for nested/repeated schemas.

  2. PR 2 — Scalar type extraction (~3,700 lines, ~55 tests): Adds the core decode pipeline — scan_all_fields_kernel for single-pass field location recording, batched 2D extraction kernels grouped by type (varint, fixed32/64, zigzag), two-phase string/bytes construction, default value substitution, and required-field checks. Covers all scalar protobuf types (int32, int64, uint32, uint64, sint32/64, fixed32/64, sfixed32/64, float, double, bool, string, bytes). The change is a pure insertion into protobuf.cu between the field classification and the assembly section — no existing code is modified.

  3. PR 3 — Repeated fields + nested messages (~4,400 lines, ~25 tests): Adds the three-phase repeated-field pipeline (count → prefix-sum offsets → scan occurrences → build LIST columns) and recursive nested-message decoding (up to 10 levels deep). Includes repeated-in-nested and repeated-in-repeated support. Again a pure insertion into the orchestrator, between the scalar section and the final assembly.

  4. PR 4 — Enum-as-string + PERMISSIVE mode (~1,900 lines, ~40 tests): Adds enum value validation, varint-to-UTF8-name conversion with GPU-parallel lookup tables, and PERMISSIVE mode null propagation (invalid enum or malformed rows nullify the entire struct row and propagate nulls to all descendants). This is the only PR that modifies a small amount of existing code (adding d_row_force_null tracking to the setup section and null propagation after the assembly).

Benchmarks (8 NVBench cases, ~1,400 lines) will be submitted separately or appended to PR 4.

Summary

This PR adds a GPU-accelerated protobuf decoder that converts LIST<INT8/UINT8> columns (one serialized protobuf message per row) into nested cuDF STRUCT columns via JNI. This is the native kernel layer that powers from_protobuf() GPU acceleration in the spark-rapids plugin.

The implementation spans ~6,500 lines of new CUDA/C++/Java code and ~3,900 lines of tests, organized into a clean four-file C++ architecture plus a validated Java schema API.

Key capabilities

  • All scalar protobuf types: int32, int64, uint32, uint64, sint32/sint64 (zigzag), fixed32/sfixed32/fixed64/sfixed64, float, double, bool, string, bytes
  • Nested messages: up to 10 levels deep, recursive decode
  • Repeated fields: both packed and unpacked encoding, auto-detected per-row
  • Repeated messages: ArrayType(StructType) — repeated nested messages with arbitrary child fields
  • Repeated-in-nested: repeated fields inside nested messages, repeated fields inside repeated messages
  • Enum-as-string: varint → validated enum → UTF-8 string name conversion, with lookup tables for GPU-parallel name resolution
  • Default values: per-field defaults for all scalar types and strings
  • Required field validation: proto2-style required field checks
  • PERMISSIVE / FAILFAST modes: configurable error handling — permissive mode nullifies malformed rows instead of throwing; invalid enum values nullify the entire struct row and propagate nulls to all descendants
  • Schema projection ready: the flattened schema representation supports decoding arbitrary subsets of fields

Performance characteristics

  • Multi-pass algorithm optimized for GPU occupancy: many simple kernels > one complex kernel
  • O(1) field-number lookup tables for scan/count kernels (up to FIELD_LOOKUP_TABLE_MAX = 4096)
  • Batched scalar extraction groups fields by type to minimize kernel launches
  • Two-phase string construction (compute lengths → prefix sum → copy) avoids pre-allocation guessing

Architecture

File structure

src/main/cpp/src/
├── protobuf.hpp            (279 lines)  Public API: types, context, validation
├── protobuf_common.cuh    (1823 lines)  Shared types, device helpers, template kernels
├── protobuf_kernels.cu    (1307 lines)  Non-template CUDA kernels
├── protobuf_builders.cu   (1719 lines)  Column builder functions
├── protobuf.cu            (1196 lines)  Entry point: decode_protobuf_to_struct
└── ProtobufJni.cpp         (278 lines)  JNI bridge

src/main/java/.../jni/
├── Protobuf.java            (116 lines)  Java public API
└── ProtobufSchemaDescriptor.java (319 lines)  Immutable schema with validation

src/test/java/.../jni/
├── ProtobufTest.java       (3565 lines)  107 decode tests
└── ProtobufSchemaDescriptorTest.java (338 lines)  13 schema validation tests

src/main/cpp/benchmarks/
└── protobuf_decode.cu      (1322 lines)  8 NVBench benchmarks

Dependency graph

Protobuf.java ──► ProtobufSchemaDescriptor.java
     │
     │ JNI
     ▼
ProtobufJni.cpp ──► protobuf.hpp
                         │
                         ▼
                    protobuf.cu (entry point, orchestration)
                         │
                         │ #include
                         ▼
                    protobuf_common.cuh (shared foundation)
                    ▲            ▲
                    │            │
          protobuf_kernels.cu  protobuf_builders.cu

Multi-pass decode algorithm

The decoder processes each batch of messages through multiple GPU passes:

  1. Count pass (count_repeated_fields_kernel): One thread per row. Scans message bytes to count repeated field occurrences and record nested message locations. Handles both packed and unpacked repeated encoding.
  2. Offset pass (thrust::exclusive_scan): Prefix sum on repeated counts to compute output array offsets.
  3. Scan pass (scan_all_fields_kernel + scan_all_repeated_occurrences_kernel): Records exact byte locations (offset + length) for every target field. Last-one-wins semantics for duplicate scalar fields.
  4. Extract pass (type-specific kernels): Parallel data extraction using pre-computed locations. Batched 2D kernel launches group fields by type (varint, fixed32, fixed64, zigzag, etc.) to minimize launch overhead.
  5. Build pass (recursive column builders): Assembles cuDF columns bottom-up. Nested structs and repeated messages are processed recursively up to MAX_NESTED_STRUCT_DECODE_DEPTH = 10.

Flattened schema representation

The protobuf schema is represented as parallel arrays passed through JNI. Fields are ordered in pre-order traversal (parent before children):

Array Description
fieldNumbers[] Protobuf field numbers
parentIndices[] Parent index in flat array (-1 for top-level)
depthLevels[] Nesting depth (0 for top-level)
wireTypes[] Expected protobuf wire type (0=varint, 1=64bit, 2=len, 5=32bit)
outputTypeIds[] cuDF type IDs for output columns
encodings[] Encoding (0=default, 1=fixed, 2=zigzag, 3=enum_string)
isRepeated[] Whether field is repeated (output becomes LIST)
isRequired[] Whether field is required (proto2)
hasDefaultValue[] Whether a default value exists
defaultInts/Floats/Bools/Strings[] Default values per field
enumValidValues[][] Sorted valid enum values per field (for binary search)
enumNames[][][] Enum name UTF-8 bytes per field (for enum-as-string)

Example for message Outer { int32 a = 1; Inner b = 2; } message Inner { int32 x = 1; string y = 2; }:

Index 0: a  (parentIdx=-1, depth=0, wireType=VARINT, type=INT32)
Index 1: b  (parentIdx=-1, depth=0, wireType=LEN,    type=STRUCT)
Index 2: x  (parentIdx=1,  depth=1, wireType=VARINT, type=INT32)
Index 3: y  (parentIdx=1,  depth=1, wireType=LEN,    type=STRING)

Test coverage

107 JUnit tests in ProtobufTest.java + 13 tests in ProtobufSchemaDescriptorTest.java, organized by feature:

Category Tests What is covered
Basic scalar types 3 INT32/64, FLOAT32/64, BOOL, STRING end-to-end
Varint & zigzag 9 Max values, zero, over-encoded zero, 10th-byte validation, zigzag min/max/negative
Wire format errors 14 Malformed varint, truncated fields (varint/string/fixed32/fixed64), partial data, wrong wire type
Unknown field skip 4 Skip varint, fixed32, fixed64, length-delimited unknowns
Last-one-wins 2 Duplicate field handling for scalars and strings
Float/double specials 2 NaN, +Inf, -Inf
Schema projection 2 Partial field decode, decode-none
Required fields 12 Present/missing in permissive/failfast, multi-row, nested required, absent parent skip
Default values 13 All scalar types, strings, empty strings, mixed defaults, multi-row
Repeated fields 10 Unpacked int32, packed double, uint32/64, packed-in-nested, packed-in-repeated-message
Nested messages 6 1-level, 3-level deep, repeated-inside-nested, repeated-in-repeated
Enum (as INT32) 8 Valid, zero, unknown, negative, multiple fields, missing
Enum-as-string 18 Valid/unknown/mixed, permissive null propagation, repeated enum, nested repeated enum, sibling field visibility
FAILFAST mode 13 All error types throw, valid data does not throw
Packed edge cases 4 Misaligned packed fixed32/64, large repeated, mixed packed/unpacked
Deep nesting 6 9-level, 10-level, zero-length nested, empty packed, large field numbers
Schema descriptor 13 Repeated+default reject, struct/list default reject, enum metadata, duplicate fields, encoding compat, depth limit, serialization roundtrip
Performance 1 Multi-field batched extraction correctness

Benchmarks

8 NVBench benchmarks in protobuf_decode.cu:

Benchmark What it stresses
BM_protobuf_flat_scalars Top-level scalar extraction throughput
BM_protobuf_nested Nested message recursive decode
BM_protobuf_repeated Top-level repeated field count/scan/extract
BM_protobuf_wide_repeated_message Wide repeated struct (many children)
BM_protobuf_repeated_child_lists Repeated-in-repeated (nested LIST)
BM_protobuf_repeated_child_string_count_scan Nested repeated string count+scan isolation
BM_protobuf_repeated_child_string_build Nested repeated string build isolation
BM_protobuf_many_repeated Many independent repeated fields

Review Guide

This PR is large (~12,000 lines total) but has a well-defined layered architecture. This guide provides a recommended reading order, key areas to focus on per file, and a mental model for understanding the code.

Recommended reading order

Read bottom-up from the API surface to the kernel internals:

Order File Focus Time estimate
1 protobuf.hpp Data structures, API contract, validation logic 15 min
2 ProtobufSchemaDescriptor.java Java-side schema validation (mirrors C++ validation) 15 min
3 Protobuf.java Public Java API, PERMISSIVE mode semantics 5 min
4 ProtobufJni.cpp JNI bridge: array conversion, local ref management 15 min
5 protobuf_common.cuh §1: types Lines 54–165: field_location, device_nested_field_descriptor, etc. 10 min
6 protobuf_common.cuh §2: device helpers Lines 167–400: read_varint, skip_field, get_field_data_location, decode_tag, lookup_field 20 min
7 protobuf_common.cuh §3: LocationProviders Lines ~400–650: TopLevelLocationProvider, NestedLocationProvider, etc. — these abstract how extraction kernels compute byte offsets 15 min
8 protobuf_common.cuh §4: template kernels extract_varint_kernel, extract_fixed_kernel, extract_lengths_kernel, copy_varlen_data_kernel, batched variants 20 min
9 protobuf_common.cuh §5: template host functions extract_typed_column, build_repeated_scalar_column, extract_and_build_string_or_bytes_column, validate_enum_and_propagate_rows 20 min
10 protobuf_kernels.cu §1: scan scan_all_fields_kernel — the core single-pass field scanner 20 min
11 protobuf_kernels.cu §2: count/scan repeated count_repeated_fields_kernel, scan_all_repeated_occurrences_kernel, shared __device__ helpers 20 min
12 protobuf_kernels.cu §3: nested scan_nested_message_fields_kernel, scan_repeated_message_children_kernel, compute kernels 20 min
13 protobuf_kernels.cu §4: validation check_required_fields_kernel, validate_enum_values_kernel, enum-string kernels 10 min
14 protobuf_builders.cu §1: utilities make_null_column, make_empty_column_safe, make_null_list_column_with_child 10 min
15 protobuf_builders.cu §2: enum-string make_enum_string_lookup_tables, build_enum_string_column, build_repeated_enum_string_column 15 min
16 protobuf_builders.cu §3: nested struct build_nested_struct_column — most complex builder, recursive depth handling 25 min
17 protobuf_builders.cu §4: repeated struct build_repeated_struct_column, build_repeated_child_list_column — repeated-in-repeated 20 min
18 protobuf.cu Entry point decode_protobuf_to_struct: orchestration, batched scalar extraction, PERMISSIVE null propagation 30 min
19 ProtobufTest.java Tests — skim by category, deep-read the tricky ones (enum-as-string, nested repeated) 30 min
20 protobuf_decode.cu Benchmarks — helper encoders and benchmark configurations 15 min

Total estimated review time: ~5-6 hours for a thorough review.

Key review areas by priority

P0: Correctness-critical

  1. scan_all_fields_kernel (protobuf_kernels.cu): This is the most important kernel. One thread per row, scans all bytes of each message, records field locations. Key things to verify:

    • Last-one-wins semantics for duplicate fields
    • Correct wire type dispatch and field skipping
    • Bounds checking at every byte read (cur < msg_end)
    • Error flag setting uses atomicCAS (no races)
    • Permissive mode: row_has_invalid_data is set on parse errors so the row can be nullified
  2. count_repeated_fields_kernel (protobuf_kernels.cu): Counts repeated field occurrences. Must correctly distinguish packed vs unpacked encoding. Packed detection: wire_type == WT_LEN but expected_wire_type != WT_LEN.

  3. build_nested_struct_column (protobuf_builders.cu): Recursive builder for nested messages. Verify:

    • Depth limit enforcement (MAX_NESTED_STRUCT_DECODE_DEPTH)
    • Correct parent-child location derivation
    • is_repeated children inside nested messages get proper LIST wrapping
    • 0-row / empty-child edge cases create correctly typed columns
  4. PERMISSIVE mode null propagation (protobuf.cu, lines ~1163-1190): Invalid enum values must nullify the entire struct row AND propagate to all descendants via apply_parent_mask_to_row_aligned_column + propagate_nulls_to_descendants.

  5. JNI memory safety (ProtobufJni.cpp): Every GetObjectArrayElement / GetByteArrayElements / GetIntArrayElements must have a matching DeleteLocalRef / ReleaseXxxArrayElements. Verify no leaks in the enum_names triple-nested loop.

P1: Robustness

  1. validate_decode_context (protobuf.hpp): Validates all schema invariants before any GPU work. Check completeness — duplicate field numbers under same parent, encoding compatibility, enum metadata non-empty for ENC_ENUM_STRING.

  2. ProtobufSchemaDescriptor.validate (Java): Mirrors C++ validation. Defensive copies in constructor, re-validation on deserialization.

  3. Varint parsing (protobuf_common.cuh read_varint): 10th byte must only use lowest bit. Truncated/malformed varints must return false.

  4. Wire type handling (skip_field, get_wire_type_size): Verify WT_SGROUP uses iterative handling with depth cap of 32 (not recursive). WT_EGROUP is rejected as standalone.

P2: Performance

  1. Batched scalar extraction (protobuf.cu, lines ~400-470): Fields are grouped into 12 categories by type and extracted with 2D kernel launches. Verify grouping logic covers all type+encoding combinations.

  2. Field lookup tables (build_field_lookup_table, build_index_lookup_table in protobuf_common.cuh): O(1) field_number → index mapping when max field number ≤ FIELD_LOOKUP_TABLE_MAX. Falls back to linear scan otherwise.

  3. String two-phase construction: extract_lengths_kernelmake_offsets_child_columncopy_varlen_data_kernel. Verify no off-by-one in offset calculations.

Things to watch for

  • Evaluation order: Never call uvector.size() after uvector.release() in the same expression. The code caches sizes before releasing.
  • Offsets column size: LIST columns require offsets with exactly num_rows + 1 elements.
  • Packed repeated in nested: count_repeated_in_nested_kernel and scan_repeated_in_nested_kernel handle both packed and unpacked within nested message boundaries.
  • Empty struct children: When building 0-row struct columns, repeated children must still be wrapped in empty LIST columns to maintain correct schema.
  • cudf::logic_error for data errors: The code uses cudf::logic_error for wire-format errors in strict mode. This is semantically imprecise (it conventionally signals API misuse), but functionally correct.

Mapping review to test coverage

If you're reviewing... Verify these test categories pass
scan_all_fields_kernel Basic scalar, varint/zigzag, wire format errors, unknown field skip, last-one-wins
count/scan_repeated Repeated fields, packed edge cases
build_nested_struct_column Nested messages, deep nesting
validate_enum + enum-string Enum (INT32), enum-as-string, PERMISSIVE mode
check_required_fields_kernel Required fields
Default value handling Default values
JNI bridge All tests (they all go through JNI)
Schema validation ProtobufSchemaDescriptorTest

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds a GPU-accelerated protocol buffer decoder with intentionally limited features, focusing on simple scalar field types. The implementation provides a JNI interface for decoding binary protobuf messages into cuDF STRUCT columns.

Key changes:

  • Implements GPU kernels for decoding protobuf varint, fixed32/64, and length-delimited (string) fields
  • Adds JNI bindings between Java and CUDA implementation
  • Provides basic test coverage for INT64 and STRING field types

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 14 comments.

Show a summary per file
File Description
src/main/java/com/nvidia/spark/rapids/jni/ProtobufSimple.java Java API providing decodeToStruct() method with parameter validation
src/test/java/com/nvidia/spark/rapids/jni/ProtobufSimpleTest.java Basic test case covering varint (INT64) and string decoding with missing fields and null messages
src/main/cpp/src/protobuf_simple.hpp C++ API declaration with documentation of supported types
src/main/cpp/src/protobuf_simple.cu CUDA implementation with three specialized kernels for varint, fixed-width, and string extraction
src/main/cpp/src/ProtobufSimpleJni.cpp JNI bridge translating Java arrays to C++ vectors and invoking decode logic
src/main/cpp/CMakeLists.txt Build configuration adding new source files to compilation targets

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
@thirtiseven
Copy link
Copy Markdown
Collaborator Author

@greptile full review

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps bot commented Dec 23, 2025

Greptile Summary

This PR delivers a comprehensive GPU-accelerated protobuf decoder (~12,000 lines of new CUDA/C++/Java) that converts LIST<INT8/UINT8> columns of serialized protobuf messages into nested cuDF STRUCT columns, powering from_protobuf() acceleration in spark-rapids. The implementation covers all scalar types, nested messages, repeated fields, packed encoding, zigzag/fixed variants, enum-as-string conversion, default values, required-field validation, and PERMISSIVE/FAILFAST modes.

The PR has undergone extensive review iteration; all previously reported correctness issues have been addressed:

  • Fixed: cudaMemcpyAsync for error-flag init (was byte-granular cudaMemsetAsync producing 0x01010101)
  • Fixed: sfixed32/sfixed64 default-value branch now uses is_integral_v to pick default_int over default_float
  • Fixed: validate_enum_values_kernel called in the repeated enum path
  • Fixed: WT_VARINT error propagation in scan_repeated_message_children_kernel
  • Fixed: Depth-level checks added to both fast-path lookups (count_repeated_fields_kernel, scan_all_repeated_occurrences_kernel)
  • Fixed: ERR_SCHEMA_TOO_LARGE replaces the reused ERR_OVERFLOW for MAX_STACK_FIELDS guard
  • Fixed: Iterative WT_SGROUP handler with depth cap in get_wire_type_size
  • Fixed: Deep copies for defaultStrings/enumValidValues/enumNames in ProtobufSchemaDescriptor
  • Fixed: readObject re-validation on deserialization
  • Fixed: All JNI local-ref leaks on early return paths

One P1 build failure remains: protobuf_decode.cu includes protobuf_common.cuh (which was split into four separate headers and no longer exists) and uses stale flat spark_rapids_jni:: namespace prefixes for types now under spark_rapids_jni::protobuf:: and spark_rapids_jni::protobuf::detail::. Since ConfigureBench(PROTOBUF_DECODE_BENCH protobuf_decode.cu) is registered in benchmarks/CMakeLists.txt, this breaks the benchmark build target. Either remove the file or update it before merging.

Two minor P2 notes: cudf::logic_error is used for wire-format data errors (should be std::runtime_error), and ERR_OVERFLOW's message only describes one of its two triggering conditions.

Confidence Score: 4/5

Safe to merge for the core library and JNI/Java code; the benchmark build target will fail until protobuf_decode.cu is updated or removed

All previously reported correctness bugs have been resolved — including the byte-granular memset for ERR_BOUNDS init, sfixed default value branch, repeated enum validation, varint error propagation, fast-path depth checks, and JNI local-ref leaks. The remaining P1 is limited to the benchmark file, which references a deleted header and stale namespaces and will fail to compile. Since the benchmark is explicitly flagged as 'to be submitted separately' in the PR description but is wired into CMakeLists.txt, it will break the benchmark build target. Fixing or removing the benchmark file brings confidence to 5/5.

src/main/cpp/benchmarks/protobuf_decode.cu — includes protobuf_common.cuh (which no longer exists) and uses stale spark_rapids_jni:: namespace prefixes inconsistent with the new spark_rapids_jni::protobuf:: structure

Important Files Changed

Filename Overview
src/main/cpp/benchmarks/protobuf_decode.cu Benchmark file includes non-existent protobuf_common.cuh and uses wrong namespace paths that don't match the new code structure — will fail to compile
src/main/cpp/src/protobuf/protobuf_kernels.cu Core GPU kernels for scanning, counting, and extracting protobuf fields; previously reported issues (OOB stack array, depth check asymmetry, varint error propagation, WT_SGROUP recursion, ERR_SCHEMA_TOO_LARGE) all appear addressed
src/main/cpp/src/protobuf/protobuf.cu Entry-point orchestrator; cudaMemcpyAsync fix for ERR_BOUNDS init and batched scalar grouping are in place; PERMISSIVE null propagation through struct and list children looks correct
src/main/cpp/src/protobuf/protobuf_builders.cu Column builders including build_repeated_enum_string_column now correctly calls launch_validate_enum_values; default-value handling for sfixed32/sfixed64 uses is_integral_v branch
src/main/cpp/src/ProtobufJni.cpp JNI bridge; previously reported local-ref leaks on null GetByteArrayElements/GetIntArrayElements paths and ExceptionCheck early-return paths for names_arr are now patched
src/main/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptor.java Immutable schema descriptor; defensive deep copies for nested arrays, readObject re-validation, wire-type check, and enumNames/enumValidValues length-mismatch guard are all present
src/main/java/com/nvidia/spark/rapids/jni/Protobuf.java Public Java API with null checks for both binaryInput and schema; encoding/type compatibility validated by ProtobufSchemaDescriptor
src/main/cpp/src/protobuf/protobuf_device_helpers.cuh Device helper functions; WT_SGROUP now iterative with depth cap, read_varint 10th-byte check in place, overflow guard in WT_LEN inner-field size added
src/main/cpp/src/protobuf/protobuf_kernels.cuh Template extraction kernels and location providers; batched scalar kernels use is_integral_v for correct default-value dispatch between int and float fixed-width types
src/main/cpp/src/protobuf/protobuf_host_helpers.hpp Host-side lookup table builders and forward declarations; lookup tables correctly bound-check field numbers against FIELD_LOOKUP_TABLE_MAX
src/main/cpp/src/protobuf/protobuf_types.cuh Type definitions and error codes; ERR_SCHEMA_TOO_LARGE added as a distinct code (10), separating it from ERR_OVERFLOW (5)
src/main/cpp/src/protobuf/protobuf.hpp Public C++ API header; schema descriptor struct, encode/wire-type helpers, and decode context cleanly defined
src/test/java/com/nvidia/spark/rapids/jni/ProtobufTest.java 107 JUnit end-to-end tests covering all major features; comprehensive coverage of edge cases, error paths, and PERMISSIVE mode
src/test/java/com/nvidia/spark/rapids/jni/ProtobufSchemaDescriptorTest.java 13 schema validation tests covering invalid configs, serialization roundtrip, and enum metadata constraints

Sequence Diagram

sequenceDiagram
    participant Java as Protobuf.java
    participant JNI as ProtobufJni.cpp
    participant Entry as protobuf.cu
    participant Scan as protobuf_kernels.cu
    participant Extract as protobuf_kernels.cuh
    participant Build as protobuf_builders.cu

    Java->>JNI: decodeToStruct(binaryInput, schema, failOnErrors)
    JNI->>JNI: validate arrays, convert to host_vectors
    JNI->>Entry: decode_protobuf_to_struct(input, context, stream, mr)
    Entry->>Entry: validate_decode_context()
    Entry->>Scan: count_repeated_fields_kernel (Pass 1: count + locate)
    Scan-->>Entry: d_repeated_info, d_nested_locations
    Entry->>Entry: thrust::exclusive_scan (Pass 2: prefix-sum offsets)
    Entry->>Scan: scan_all_fields_kernel (Pass 3a: scalar field locations)
    Entry->>Scan: scan_all_repeated_occurrences_kernel (Pass 3b: repeated locations)
    Entry->>Extract: extract_varint/fixed/lengths_batched_kernel (Pass 4: parallel extraction)
    Entry->>Build: build_nested_struct_column (Pass 5: recursive assembly)
    Build->>Scan: scan_nested_message_fields_kernel
    Build->>Scan: count/scan_repeated_in_nested_kernel
    Build-->>Entry: nested cudf::column
    Entry->>Entry: check_required_fields, validate_enum_values
    Entry->>Entry: PERMISSIVE null propagation
    Entry-->>JNI: std::unique_ptr cudf::column (STRUCT)
    JNI-->>Java: ColumnVector handle
Loading

Comments Outside Diff (1)

  1. src/main/cpp/benchmarks/protobuf_decode.cu, line 27-28 (link)

    Benchmark will not compile — stale header and wrong namespaces

    protobuf_common.cuh does not exist in this PR. The new file structure splits it into protobuf_types.cuh, protobuf_device_helpers.cuh, protobuf_host_helpers.hpp, and protobuf_kernels.cuh. The benchmark also uses flat spark_rapids_jni:: prefixes for types that are now in the nested spark_rapids_jni::protobuf:: and spark_rapids_jni::protobuf::detail:: namespaces — for example:

    Benchmark uses Correct qualified name
    spark_rapids_jni::nested_field_descriptor spark_rapids_jni::protobuf::nested_field_descriptor
    spark_rapids_jni::protobuf_detail::field_location spark_rapids_jni::protobuf::detail::field_location
    spark_rapids_jni::protobuf_detail::repeated_occurrence spark_rapids_jni::protobuf::detail::repeated_occurrence
    spark_rapids_jni::wire_type_value(...) spark_rapids_jni::protobuf::wire_type_value(...)
    spark_rapids_jni::proto_wire_type::VARINT spark_rapids_jni::protobuf::proto_wire_type::VARINT
    spark_rapids_jni::protobuf_decode_context spark_rapids_jni::protobuf::protobuf_decode_context

    The protobuf_common.cuh include should be replaced with the actual split headers, and all spark_rapids_jni:: type references updated to include the ::protobuf:: or ::protobuf::detail:: component.

    Since the PR description notes benchmarks will be submitted separately (appended to PR 4), consider either removing this file from the current PR or adding a #error guard so CI makes the mismatch visible rather than failing silently.

Reviews (54): Last reviewed commit: "apply refactor on headers" | Re-trigger Greptile

Copy link
Copy Markdown
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (6)

  1. src/main/cpp/src/protobuf_simple.cu, line 89-91 (link)

    logic: potential overflow: len64 can be up to 2^64-1, but casting to int on line 90 can overflow if len64 > INT_MAX

  2. src/main/cpp/src/protobuf_simple.cu, line 323-324 (link)

    logic: potential overflow: len64 can be larger than INT_MAX, but casting to int on line 324 will overflow

  3. src/main/cpp/src/protobuf_simple.cu, line 375-376 (link)

    logic: race condition: multiple threads write to *error_flag without atomics, causing undefined behavior when multiple threads encounter errors simultaneously

    Then in kernels, use atomicOr(error_flag, 1) instead of *error_flag = 1

  4. src/main/cpp/src/protobuf_simple.cu, line 398-407 (link)

    logic: protobuf uses zigzag encoding for signed integers (sint32/sint64), but varint decoding here treats them as unsigned - decoding negative values will produce incorrect results. Are you only supporting unsigned int32/int64, or should zigzag decoding be implemented for signed types?

  5. src/main/cpp/src/protobuf_simple.cu, line 240 (link)

    syntax: type punning through reinterpret_cast of incompatible pointer types is undefined behavior in C++

  6. src/main/cpp/src/protobuf_simple.cu, line 248 (link)

    syntax: type punning through reinterpret_cast of incompatible pointer types is undefined behavior in C++

6 files reviewed, 6 comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
@thirtiseven thirtiseven requested a review from Copilot December 25, 2025 03:43
@thirtiseven
Copy link
Copy Markdown
Collaborator Author

@greptile full review

Copy link
Copy Markdown
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (2)

  1. src/main/cpp/src/protobuf_simple.cu, line 186 (link)

    logic: zigzag decode uses signed right shift which is correct for signed types, but applied to unsigned v

  2. src/main/cpp/src/ProtobufSimpleJni.cpp, line 55 (link)

    logic: encodings is constructed from n_type_scales but then out_types is also constructed using n_type_scales[i] as the scale parameter, which would be wrong for non-decimal types where this represents encoding

6 files reviewed, 2 comments

Edit Code Review Agent Settings | Greptile

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 13 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 10 comments.


💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
@thirtiseven
Copy link
Copy Markdown
Collaborator Author

@greptileai full review

Copy link
Copy Markdown
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
@nvauto
Copy link
Copy Markdown
Collaborator

nvauto commented Jan 19, 2026

NOTE: release/26.02 has been created from main. Please retarget your PR to release/26.02 if it should be included in the release.

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
@nvauto
Copy link
Copy Markdown
Collaborator

nvauto commented Mar 16, 2026

NOTE: release/26.04 has been created from main. Please retarget your PR to release/26.04 if it should be included in the release.

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
namespace spark_rapids_jni {

// Encoding constants
constexpr int ENC_DEFAULT = 0;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Create an enum to improve readability like:

enum class type_id : int32_t

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

constexpr int MAX_FIELD_NUMBER = (1 << 29) - 1;

// Wire type constants
constexpr int WT_VARINT = 0;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create an enum.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

cudf::size_type num_rows,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a fast checking path to avoid unnessary GPU tasks.
if (parent_null_count == 0) return;

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

stream);
col.set_null_count(child_view.size() - valid_count);
} else {
auto child_mask = cudf::detail::copy_bitmask(parent_mask_ptr, 0, num_rows, stream, mr);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The if path considers view.offset, but else path does not, is this intended?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that's intended, added a CUDF_EXPECTS to ensure it.

words from ai:

apply_parent_mask_to_row_aligned_column takes an owning cudf::column&, not a sliced column_view. For an owning column, column::view() / mutable_view() constructs a view with offset = 0, and child() likewise returns the owning child column rather than a sliced child view.

So in the current protobuf decoder path both branches are effectively operating on offset-0 columns.

The nullable branch still passes child_view.offset() because inplace_bitmask_and is a generic view-oriented API; in this call site that value is always 0. I added the CUDF_EXPECTS(child_view.offset() == 0, ...) in the non-nullable branch to make that invariant explicit and to prevent future misuse if this helper is ever reused on non-owning/sliced inputs.`

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
@revans2
Copy link
Copy Markdown
Collaborator

revans2 commented Mar 17, 2026

How does the performance compare to the CPU? It is great that we have this all working. I agree that it is very large, but unless it is faster than running it on the CPU, then there is not much of a reason to do it. I am especially concerned about large protobuffers and the memory access pattern. A thread per entry is potentially very problematic, as we found for large JSON strings in other work that is similar to this. How does validation work? I didn't see anything in there about detecting malformed protobufs and returning an error.

@thirtiseven
Copy link
Copy Markdown
Collaborator Author

Thanks for the review @revans2

How does the performance compare to the CPU? It is great that we have this all working. I agree that it is very large, but unless it is faster than running it on the CPU, then there is not much of a reason to do it.

Performance is looking positive by far. We're seeing 3~5x speedup over CPU on a real customer schema (~200 field nested message with repeated structs, ~500 output columns). The customer has also confirmed cost savings in their end-to-end testing (using different schemas).

I'll add more detailed benchmark numbers in the plugin pr NVIDIA/spark-rapids#14354 description soon.

I am especially concerned about large protobuffers and the memory access pattern. A thread per entry is potentially very problematic, as we found for large JSON strings in other work that is similar to this.

Yes it could be problematic for large protobuffers, but the situation is generally better than it was with getJsonObject.

  • Protocol buffer messages are typically smaller and more like records than tables (not like in JSON), and they are similar in size, at least for the customer.
  • Protobuf is in binary TLV format, so there is less branching and heavy string processing in the Scan and Count passes. it is basically an O(n) scan. And in the Extract and Build passes, we don't use the a-thread-per-entry pattern.

I will add some large protobuffers testing in performance testsing to see how bad it is.

How does validation work? I didn't see anything in there about detecting malformed protobufs and returning an error.

Yes there are many validations in the code. We have 12 distinct error codes, which cover all Protobuf wire-format failure modes:

    switch (code) {
      case ERR_BOUNDS: return "Protobuf decode error: message data out of bounds";
      case ERR_VARINT: return "Protobuf decode error: invalid or truncated varint";
      case ERR_FIELD_NUMBER: return "Protobuf decode error: invalid field number";
      case ERR_WIRE_TYPE: return "Protobuf decode error: unexpected wire type";
      case ERR_OVERFLOW: return "Protobuf decode error: length-delimited field overflows message";
      case ERR_FIELD_SIZE: return "Protobuf decode error: invalid field size";
      case ERR_SKIP: return "Protobuf decode error: unable to skip unknown field";
      case ERR_FIXED_LEN:
        return "Protobuf decode error: invalid fixed-width or packed field length";
      case ERR_REQUIRED: return "Protobuf decode error: missing required field";
      case ERR_SCHEMA_TOO_LARGE:
        return "Protobuf decode error: schema exceeds maximum supported repeated fields per kernel "
               "(128)";
      case ERR_MISSING_ENUM_META:
        return "Protobuf decode error: missing or mismatched enum metadata for enum-as-string "
               "field";
      case ERR_REPEATED_COUNT_MISMATCH:
        return "Protobuf decode error: repeated-field count/scan mismatch";
      default: return "Protobuf decode error: unknown error";

And two error modes are supported, matching Spark's from_protobuf semantics: FAILFAST (fail_on_errors=true) and PERMISSIVE (fail_on_errors=false)

There are some related java unit tests: testWrongWireType, testSkipUnknownVarintField , testFieldNumberZeroInvalid, etc.

@revans2
Copy link
Copy Markdown
Collaborator

revans2 commented Mar 18, 2026

That is really great work. Thanks for the detailed analysis.

Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Signed-off-by: Haoyang Li <haoyangl@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants