Commit 63455f2

Add trt decoder (#307)
## Add TensorRT Decoder Plugin for Quantum Error Correction

### Overview

This PR introduces a new TensorRT-based decoder plugin for quantum error correction, leveraging NVIDIA TensorRT for accelerated neural network inference in QEC applications.

### Key Features

- **TensorRT Integration**: Full TensorRT runtime integration with support for both ONNX model loading and pre-built engine loading
- **Flexible Precision Support**: Configurable precision modes (fp16, bf16, int8, fp8, tf32, best) with automatic hardware capability detection
- **Memory Management**: Efficient CUDA memory allocation and stream-based execution
- **Parameter Validation**: Comprehensive input validation with clear error messages
- **Python Utilities**: ONNX to TensorRT engine conversion script for model preprocessing

### Technical Implementation

- **Core Decoder Class**: `trt_decoder` implementing the `decoder` interface with a TensorRT backend
- **Hardware Detection**: Automatic GPU capability detection for optimal precision selection
- **Error Handling**: Robust error handling with graceful fallbacks and informative error messages
- **Plugin Architecture**: CMake-based plugin system with conditional TensorRT linking

### Files Added/Modified

- `libs/qec/include/cudaq/qec/trt_decoder_internal.h` - Internal API declarations
- `libs/qec/lib/decoders/plugins/trt_decoder/trt_decoder.cpp` - Main decoder implementation
- `libs/qec/lib/decoders/plugins/trt_decoder/CMakeLists.txt` - Plugin build configuration
- `libs/qec/python/cudaq_qec/plugins/tensorrt_utils/build_engine_from_onnx.py` - Python utility
- `libs/qec/unittests/test_trt_decoder.cpp` - Comprehensive unit tests
- Updated CMakeLists.txt files for integration

### Testing

- ✅ All 8 unit tests passing
- Parameter validation tests
- File loading utility tests
- Edge case handling tests
- Error condition testing

### Usage Example

```cpp
// Load from an ONNX model
cudaqx::heterogeneous_map params;
params.insert("onnx_load_path", "model.onnx");
params.insert("precision", "fp16");
auto decoder = std::make_unique<trt_decoder>(H, params);

// Or load a pre-built engine
params.clear();
params.insert("engine_load_path", "model.trt");
decoder = std::make_unique<trt_decoder>(H, params);
```

### Dependencies

- TensorRT 10.13.3.9+
- CUDA 12.0+
- NVIDIA GPU with appropriate compute capability

### Performance Benefits

- GPU-accelerated inference for QEC decoding
- Optimized precision selection based on hardware capabilities
- Efficient memory usage with CUDA streams
- Reduced latency compared to CPU-based decoders

This implementation provides a production-ready TensorRT decoder plugin that can significantly accelerate quantum error correction workflows while maintaining compatibility with the existing CUDA-Q QEC framework.

---------

Signed-off-by: Scott Thornton <[email protected]>
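As a companion to the usage example above, here is a minimal sketch of what the `engine_load_path` branch typically amounts to with the standard TensorRT 10 runtime API. It is illustrative only, not the plugin's actual implementation; `TrtLogger` and `load_prebuilt_engine` are names invented for this sketch.

```cpp
// Illustrative sketch only: deserializing a pre-built serialized engine,
// as the "engine_load_path" option is described to do. Uses the public
// TensorRT 10 runtime API; error handling is abbreviated.
#include "NvInfer.h"

#include <cstdio>
#include <fstream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

class TrtLogger : public nvinfer1::ILogger {
  void log(Severity severity, const char *msg) noexcept override {
    if (severity <= Severity::kWARNING)
      std::fprintf(stderr, "[TRT] %s\n", msg);
  }
};

std::unique_ptr<nvinfer1::ICudaEngine>
load_prebuilt_engine(const std::string &engine_path) {
  // Read the serialized engine into memory.
  std::ifstream file(engine_path, std::ios::binary | std::ios::ate);
  if (!file)
    throw std::runtime_error("cannot open engine file: " + engine_path);
  std::vector<char> blob(static_cast<size_t>(file.tellg()));
  file.seekg(0);
  file.read(blob.data(), static_cast<std::streamsize>(blob.size()));

  // Deserialize with the TensorRT runtime. The runtime is kept alive for
  // the process lifetime so it safely outlives any engine it produced.
  static TrtLogger logger;
  static std::unique_ptr<nvinfer1::IRuntime> runtime(
      nvinfer1::createInferRuntime(logger));
  return std::unique_ptr<nvinfer1::ICudaEngine>(
      runtime->deserializeCudaEngine(blob.data(), blob.size()));
}
```

Building from an ONNX model instead goes through the builder and parser path declared in `trt_decoder_internal.h`, shown further below.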
1 parent 55f79c0 commit 63455f2

File tree

21 files changed: +2526, -15 lines

.gitattributes

Lines changed: 1 addition & 0 deletions

```diff
@@ -0,0 +1 @@
+*.onnx filter=lfs diff=lfs merge=lfs -text
```

.github/workflows/all_libs.yaml

Lines changed: 16 additions & 2 deletions

```diff
@@ -64,7 +64,21 @@ jobs:
 
       - name: Install build requirements
         run: |
-          apt install -y --no-install-recommends gfortran libblas-dev
+          apt install -y --no-install-recommends gfortran libblas-dev wget
+
+      - name: Install TensorRT (amd64)
+        if: matrix.platform == 'amd64'
+        run: |
+          apt-cache search tensorrt | awk '{print "Package: "$1"\nPin: version *+cuda${{ matrix.cuda_version }}\nPin-Priority: 1001\n"}' | tee /etc/apt/preferences.d/tensorrt-cuda${{ matrix.cuda_version }}.pref > /dev/null
+          apt update
+          apt install -y tensorrt-dev
+
+      - name: Install TensorRT (arm64)
+        if: matrix.platform == 'arm64'
+        run: |
+          apt-cache search tensorrt | awk '{print "Package: "$1"\nPin: version *+cuda13.0\nPin-Priority: 1001\n"}' | tee /etc/apt/preferences.d/tensorrt-cuda13.0.pref > /dev/null
+          apt update
+          apt install -y tensorrt-dev
 
       - name: Build
         id: build
@@ -92,7 +106,7 @@ jobs:
           LD_LIBRARY_PATH: ${{ env.MPI_PATH }}/lib:${{ env.LD_LIBRARY_PATH }}
         shell: bash
         run: |
-          pip install numpy pytest cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} torch lightning ml_collections mpi4py transformers quimb opt_einsum torch nvidia-cublas-cu${{ steps.config.outputs.cuda_major }} cuquantum-python-cu${{ steps.config.outputs.cuda_major }}==25.09
+          pip install numpy pytest cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} torch lightning ml_collections mpi4py transformers quimb opt_einsum torch nvidia-cublas cuquantum-python-cu${{ steps.config.outputs.cuda_major }}==25.09
           # The following tests are needed for docs/sphinx/examples/qec/python/tensor_network_decoder.py.
           if [ "$(uname -m)" == "x86_64" ]; then
             # Stim is not currently available on manylinux ARM wheels, so only
```

.github/workflows/build_wheels.yaml

Lines changed: 58 additions & 6 deletions

```diff
@@ -63,16 +63,59 @@ jobs:
         build-type: Release
 
     steps:
-      - name: Get code
-        uses: actions/checkout@v4
-        with:
-          set-safe-directory: true
-
       - name: Configure
         id: config
         run: |
           cuda_major=`echo ${{ matrix.cuda_version }} | cut -d . -f1`
           echo "cuda_major=$cuda_major" >> $GITHUB_OUTPUT
+          # Map CUDA 12.6 to 12.9 for TensorRT filename
+          if [ "${{ matrix.cuda_version }}" == "12.6" ]; then
+            tensorrt_cuda_version="12.9"
+            tensorrt_cuda_major="12"
+          else
+            tensorrt_cuda_version="${{ matrix.cuda_version }}"
+            tensorrt_cuda_major="$cuda_major"
+          fi
+          echo "tensorrt_cuda_version=$tensorrt_cuda_version" >> $GITHUB_OUTPUT
+          echo "tensorrt_cuda_major=$tensorrt_cuda_major" >> $GITHUB_OUTPUT
+          tensorrt_major_version="10.13.3"
+          tensorrt_minor_version="9"
+          tensorrt_version="${tensorrt_major_version}.${tensorrt_minor_version}"
+          echo "tensorrt_major_version=$tensorrt_major_version" >> $GITHUB_OUTPUT
+          echo "tensorrt_version=$tensorrt_version" >> $GITHUB_OUTPUT
+
+      - name: Install TensorRT (amd64)
+        shell: bash
+        if: matrix.platform == 'amd64'
+        run: |
+          mkdir -p /trt_download
+          pushd /trt_download
+          pwd
+          wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${{ steps.config.outputs.tensorrt_major_version }}/tars/TensorRT-${{ steps.config.outputs.tensorrt_version }}.Linux.x86_64-gnu.cuda-${{ steps.config.outputs.tensorrt_cuda_version }}.tar.gz
+          tar -zxvf TensorRT-${{ steps.config.outputs.tensorrt_version }}.Linux.x86_64-gnu.cuda-${{ steps.config.outputs.tensorrt_cuda_version }}.tar.gz
+          pwd
+          popd
+          find /trt_download/TensorRT-${{ steps.config.outputs.tensorrt_version }} -name "NvInfer.h"
+          find /trt_download/TensorRT-${{ steps.config.outputs.tensorrt_version }} -name "NvInferRuntime.h"
+
+      - name: Install TensorRT (arm64)
+        shell: bash
+        if: matrix.platform == 'arm64'
+        run: |
+          mkdir -p /trt_download
+          pushd /trt_download
+          pwd
+          wget https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/${{ steps.config.outputs.tensorrt_major_version }}/tars/TensorRT-${{ steps.config.outputs.tensorrt_version }}.Linux.aarch64-gnu.cuda-13.0.tar.gz
+          tar -zxvf TensorRT-${{ steps.config.outputs.tensorrt_version }}.Linux.aarch64-gnu.cuda-13.0.tar.gz
+          pwd
+          popd
+          find /trt_download/TensorRT-${{ steps.config.outputs.tensorrt_version }} -name "NvInfer.h"
+          find /trt_download/TensorRT-${{ steps.config.outputs.tensorrt_version }} -name "NvInferRuntime.h"
+
+      - name: Get code
+        uses: actions/checkout@v4
+        with:
+          set-safe-directory: true
 
       # Do this early to help validate user inputs (if present)
       - name: Fetch assets
@@ -123,6 +166,7 @@
             --cudaq-prefix /usr/local/cudaq \
             --build-type ${{ inputs.build_type }} \
             --python-version ${{ matrix.python }} \
+            --tensorrt-path /trt_download/TensorRT-${{ steps.config.outputs.tensorrt_version }} \
             --version ${{ inputs.version || '0.99.99' }}
 
       - name: Upload artifact
@@ -332,11 +376,19 @@
         cuda_version: ['12.6', '13.0']
 
     steps:
+
+      - name: Install git for LFS
+        shell: bash
+        run: |
+          apt update
+          apt install -y --no-install-recommends git git-lfs
+
       - name: Get code
         uses: actions/checkout@v4
         with:
           set-safe-directory: true
-
+          lfs: true # download assets file(s) for TRT tests
+
       - name: Configure
         id: config
         run: |
```
.github/workflows/lib_qec.yaml

Lines changed: 18 additions & 1 deletion

```diff
@@ -61,6 +61,23 @@ jobs:
       # ========================================================================
       # Build library
       # ========================================================================
+      - name: Install build requirements
+        run: |
+          apt install -y --no-install-recommends gfortran libblas-dev wget
+
+      - name: Install TensorRT (amd64)
+        if: matrix.platform == 'amd64'
+        run: |
+          apt-cache search tensorrt | awk '{print "Package: "$1"\nPin: version *+cuda${{ matrix.cuda_version }}\nPin-Priority: 1001\n"}' | tee /etc/apt/preferences.d/tensorrt-cuda${{ matrix.cuda_version }}.pref > /dev/null
+          apt update
+          apt install -y tensorrt-dev
+
+      - name: Install TensorRT (arm64)
+        if: matrix.platform == 'arm64'
+        run: |
+          apt-cache search tensorrt | awk '{print "Package: "$1"\nPin: version *+cuda13.0\nPin-Priority: 1001\n"}' | tee /etc/apt/preferences.d/tensorrt-cuda13.0.pref > /dev/null
+          apt update
+          apt install -y tensorrt-dev
 
       - name: Build
         id: build
@@ -86,7 +103,7 @@
       - name: Install python requirements
         shell: bash
         run: |
-          pip install numpy pytest cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} quimb opt_einsum torch nvidia-cublas-cu${{ steps.config.outputs.cuda_major }} cuquantum-python-cu${{ steps.config.outputs.cuda_major }}==25.09
+          pip install numpy pytest cupy-cuda${{ steps.config.outputs.cuda_major }}x cuquantum-cu${{ steps.config.outputs.cuda_major }} quimb opt_einsum torch nvidia-cublas cuquantum-python-cu${{ steps.config.outputs.cuda_major }}==25.09
           # The following tests are needed for docs/sphinx/examples/qec/python/tensor_network_decoder.py.
           if [ "$(uname -m)" == "x86_64" ]; then
             # Stim is not currently available on manylinux ARM wheels, so only
```

.github/workflows/scripts/build_wheels.sh

Lines changed: 17 additions & 2 deletions

```diff
@@ -22,6 +22,8 @@ show_help() {
   echo "  --cudaq-prefix       Path to CUDA-Q's install prefix"
   echo "                       (default: \$HOME/.cudaq)"
   echo "  --python-version     Python version to build wheel for (e.g. 3.11)"
+  echo "  --tensorrt-path      Path to TensorRT installation directory"
+  echo "                       (default: /trt_download/TensorRT-10.13.3.9)"
   echo "  --devdeps            Build wheels suitable for internal testing"
   echo "                       (not suitable for distribution but sometimes"
   echo "                       helpful for debugging)"
@@ -68,6 +70,15 @@ parse_options() {
           exit 1
         fi
         ;;
+      --tensorrt-path)
+        if [[ -n "$2" && "$2" != -* ]]; then
+          tensorrt_path=("$2")
+          shift 2
+        else
+          echo "Error: Argument for $1 is missing" >&2
+          exit 1
+        fi
+        ;;
       --devdeps)
         devdeps=true
         shift 1
@@ -99,6 +110,7 @@ parse_options() {
   cudaq_prefix=$HOME/.cudaq
   build_type=Release
   python_version=3.11
+  tensorrt_path=/trt_download/TensorRT-10.13.3.9
   devdeps=false
   wheels_version=0.0.0
   cuda_version=12
@@ -136,7 +148,7 @@ export CUDAQX_SOLVERS_VERSION=$wheels_version
 cd libs/qec
 cp pyproject.toml.cu${cuda_version} pyproject.toml
 
-SKBUILD_CMAKE_ARGS="-DCUDAQ_DIR=$cudaq_prefix/lib/cmake/cudaq"
+SKBUILD_CMAKE_ARGS="-DCUDAQ_DIR=$cudaq_prefix/lib/cmake/cudaq;-DTENSORRT_ROOT=$tensorrt_path"
 if ! $devdeps; then
   SKBUILD_CMAKE_ARGS+=";-DCMAKE_CXX_COMPILER_EXTERNAL_TOOLCHAIN=/opt/rh/gcc-toolset-11/root/usr/lib/gcc/${ARCH}-redhat-linux/11/"
 fi
@@ -146,9 +158,12 @@ $python -m build --wheel
 
 CUDAQ_EXCLUDE_LIST=$(for f in $(find $cudaq_prefix/lib -name "*.so" -printf "%P\n" | sort); do echo "--exclude $f"; done | tr '\n' ' ')
 
-LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$(pwd)/_skbuild/lib" \
+LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$(pwd)/_skbuild/lib:$tensorrt_path/lib" \
 $python -m auditwheel -v repair dist/*.whl $CUDAQ_EXCLUDE_LIST \
   --wheel-dir /wheels \
+  --exclude libcudart.so.${cuda_version} \
+  --exclude libnvinfer.so.10 \
+  --exclude libnvonnxparser.so.10 \
   ${PLAT_STR}
 
 # ==============================================================================
```

New LFS-tracked test asset (Git LFS pointer file)

Lines changed: 3 additions & 0 deletions

```diff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f739c30d81b9fcdf668db6ba4211fc5616c8b462b41fdae07ec2d6e0c069064c
+size 191673
```
libs/qec/include/cudaq/qec/trt_decoder_internal.h

Lines changed: 56 additions & 0 deletions

```diff
@@ -0,0 +1,56 @@
+/*******************************************************************************
+ * Copyright (c) 2024 - 2025 NVIDIA Corporation & Affiliates.                  *
+ * All rights reserved.                                                        *
+ *                                                                             *
+ * This source code and the accompanying materials are made available under   *
+ * the terms of the Apache License 2.0 which accompanies this distribution.   *
+ ******************************************************************************/
+
+#pragma once
+
+#include "cudaq/qec/decoder.h"
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "NvInfer.h"
+#include "NvOnnxParser.h"
+
+namespace cudaq::qec::trt_decoder_internal {
+
+/// @brief Validates TRT decoder parameters
+/// @param params The parameter map to validate
+/// @throws std::runtime_error if parameters are invalid
+void validate_trt_decoder_parameters(const cudaqx::heterogeneous_map &params);
+
+/// @brief Loads a binary file into memory
+/// @param filename Path to the file to load
+/// @return Vector containing the file contents
+/// @throws std::runtime_error if file cannot be opened
+std::vector<char> load_file(const std::string &filename);
+
+/// @brief Builds a TensorRT engine from an ONNX model
+/// @param onnx_model_path Path to the ONNX model file
+/// @param params Configuration parameters
+/// @param logger TensorRT logger instance
+/// @return Unique pointer to the built TensorRT engine
+/// @throws std::runtime_error if engine building fails
+std::unique_ptr<nvinfer1::ICudaEngine>
+build_engine_from_onnx(const std::string &onnx_model_path,
+                       const cudaqx::heterogeneous_map &params,
+                       nvinfer1::ILogger &logger);
+
+/// @brief Saves a TensorRT engine to a file
+/// @param engine The engine to save
+/// @param file_path Path where to save the engine
+/// @throws std::runtime_error if saving fails
+void save_engine_to_file(nvinfer1::ICudaEngine *engine,
+                         const std::string &file_path);
+
+/// @brief Parses and configures precision settings for TensorRT
+/// @param precision The precision string (fp16, bf16, int8, fp8, noTF32, best)
+/// @param config TensorRT builder config instance
+void parse_precision(const std::string &precision,
+                     nvinfer1::IBuilderConfig *config);
+
+} // namespace cudaq::qec::trt_decoder_internal
```
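Taken together, these declarations support an ONNX-to-engine round trip: validate the parameters, build once, cache the serialized engine, and read it back on later runs. Below is a minimal sketch of how the internal API composes, using only the signatures declared above; it is illustrative, not the plugin's actual code (the real call sites live in `trt_decoder.cpp`, and the file paths here are placeholders).

```cpp
// Illustrative composition of the internal API declared above; not the
// plugin's actual code. Paths and the caching policy are assumptions.
#include "cudaq/qec/trt_decoder_internal.h"

using namespace cudaq::qec::trt_decoder_internal;

void build_and_cache_engine(const cudaqx::heterogeneous_map &params,
                            nvinfer1::ILogger &logger) {
  // Throws std::runtime_error on invalid or conflicting options.
  validate_trt_decoder_parameters(params);

  // Build an engine from the ONNX model, applying the precision flags
  // that parse_precision() derives from the "precision" parameter.
  auto engine = build_engine_from_onnx("model.onnx", params, logger);

  // Cache the serialized engine so later runs can pass "engine_load_path"
  // and skip the comparatively slow builder step entirely.
  save_engine_to_file(engine.get(), "model.trt");

  // load_file() is the helper later runs use to read the blob back.
  std::vector<char> blob = load_file("model.trt");
  (void)blob; // deserialization is handled inside the decoder itself
}
```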

libs/qec/lib/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -25,6 +25,7 @@ add_library(${LIBRARY_NAME} SHARED
 )
 
 add_subdirectory(decoders/plugins/example)
+add_subdirectory(decoders/plugins/trt_decoder)
 add_subdirectory(codes)
 add_subdirectory(device)
 
```