Skip to content

Commit 89e2090

Browse files
[mlir-tensorrt] Add TensorRT 8.6 support (#391)
This PR makes the following changes: it makes TensorRT 10.5 the default version, adds TensorRT 8.6 download support, adds TensorRT 8.6 to CI, removes the TensorRT 9 checks from CI to work around a device disk-space error, and fixes the tests to support the above changes.
1 parent 3a8362c commit 89e2090

File tree

10 files changed

+422
-378
lines changed

10 files changed

+422
-378
lines changed

.github/workflows/mlir-tensorrt-ci.yml

+10-10
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ jobs:
148148
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
149149
-DMLIR_TRT_PACKAGE_CACHE_DIR=/.cache.cpm \
150150
-DMLIR_TRT_ENABLE_ASSERTIONS=ON \
151-
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.2 \
151+
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.5 \
152152
-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
153153
-DMLIR_TRT_USE_LINKER=lld \
154154
-DMLIR_EXECUTOR_ENABLE_GPU_INTEGRATION_TESTS=OFF
@@ -191,7 +191,7 @@ jobs:
191191
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
192192
-DMLIR_TRT_PACKAGE_CACHE_DIR=/.cache.cpm \
193193
-DMLIR_TRT_ENABLE_ASSERTIONS=ON \
194-
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.2 \
194+
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.5 \
195195
-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
196196
-DMLIR_TRT_USE_LINKER=lld \
197197
-DMLIR_EXECUTOR_ENABLE_GPU_INTEGRATION_TESTS=OFF \
@@ -209,8 +209,8 @@ jobs:
209209
210210
bash build_and_test.sh
211211
212-
# Run LIT tests with TensorRT 9
213-
- name: Run MLIR-TensorRT lit tests with TensorRT 9
212+
# Run LIT tests with TensorRT 8
213+
- name: Run MLIR-TensorRT lit tests with TensorRT 8
214214
uses: addnab/docker-run-action@v3
215215
with:
216216
image: ${{ env.DEFAULT_IMAGE }}
@@ -235,7 +235,7 @@ jobs:
235235
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
236236
-DMLIR_TRT_PACKAGE_CACHE_DIR=/.cache.cpm \
237237
-DMLIR_TRT_ENABLE_ASSERTIONS=ON \
238-
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=9.2.0.5 \
238+
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=8.6.1.6 \
239239
-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
240240
-DMLIR_TRT_USE_LINKER=lld \
241241
-DMLIR_EXECUTOR_ENABLE_GPU_INTEGRATION_TESTS=OFF
@@ -324,7 +324,7 @@ jobs:
324324
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
325325
-DMLIR_TRT_PACKAGE_CACHE_DIR=/.cache.cpm \
326326
-DMLIR_TRT_ENABLE_ASSERTIONS=ON \
327-
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.2 \
327+
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.5 \
328328
-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
329329
-DMLIR_TRT_USE_LINKER=lld \
330330
-DMLIR_EXECUTOR_ENABLE_GPU_INTEGRATION_TESTS=OFF
@@ -367,7 +367,7 @@ jobs:
367367
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
368368
-DMLIR_TRT_PACKAGE_CACHE_DIR=/.cache.cpm \
369369
-DMLIR_TRT_ENABLE_ASSERTIONS=ON \
370-
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.2 \
370+
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=10.5 \
371371
-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
372372
-DMLIR_TRT_USE_LINKER=lld \
373373
-DMLIR_EXECUTOR_ENABLE_GPU_INTEGRATION_TESTS=OFF \
@@ -385,8 +385,8 @@ jobs:
385385
386386
bash build_and_test.sh
387387
388-
# Run LIT tests with TensorRT 9
389-
- name: Run MLIR-TensorRT lit tests with TensorRT 9
388+
# Run LIT tests with TensorRT 8
389+
- name: Run MLIR-TensorRT lit tests with TensorRT 8
390390
uses: addnab/docker-run-action@v3
391391
with:
392392
image: ${{ env.DEFAULT_IMAGE }}
@@ -411,7 +411,7 @@ jobs:
411411
-DCMAKE_BUILD_TYPE=RelWithDebInfo \
412412
-DMLIR_TRT_PACKAGE_CACHE_DIR=/.cache.cpm \
413413
-DMLIR_TRT_ENABLE_ASSERTIONS=ON \
414-
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=9.2.0.5 \
414+
-DMLIR_TRT_DOWNLOAD_TENSORRT_VERSION=8.6.1.6 \
415415
-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \
416416
-DMLIR_TRT_USE_LINKER=lld \
417417
-DMLIR_EXECUTOR_ENABLE_GPU_INTEGRATION_TESTS=OFF

mlir-tensorrt/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ mtrt_option(MLIR_TRT_ENABLE_EXECUTOR "Build the Executor dialect and MLIR-Tensor
5555
mtrt_option(MLIR_TRT_ENABLE_NCCL "Enable the NCCL runtime module" ON)
5656

5757
set(MLIR_TRT_TENSORRT_DIR "" CACHE STRING "Path to TensorRT install directory")
58-
set(MLIR_TRT_DOWNLOAD_TENSORRT_VERSION "10.2" CACHE STRING
58+
set(MLIR_TRT_DOWNLOAD_TENSORRT_VERSION "10.5" CACHE STRING
5959
"Version of TensorRT to download and use. It overrides MLIR_TRT_TENSORRT_DIR.")
6060
set(MLIR_TRT_PACKAGE_CACHE_DIR "" CACHE STRING "Directory where to cache downloaded C++ packages")
6161
set(MLIR_TRT_USE_LINKER "" CACHE STRING "Specify a linker to use (e.g. LLD); this is just an alias for LLVM_USE_LINKER")

mlir-tensorrt/build_tools/cmake/Dependencies.cmake

+34-4
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,15 @@ function(download_tensorrt)
8686
if(ARG_VERSION VERSION_EQUAL "10.2")
8787
set(ARG_VERSION "10.2.0.19")
8888
endif()
89+
# Canonicalize "10.5" version by setting it to the latest public TRT 10.5 version.
90+
if(ARG_VERSION VERSION_EQUAL "10.5")
91+
set(ARG_VERSION "10.5.0.18")
92+
endif()
8993

9094
set(downloadable_versions
91-
"9.0.1.4" "9.1.0.4" "9.2.0.5"
95+
"8.6.1.6" "9.0.1.4" "9.1.0.4" "9.2.0.5"
9296
"10.0.0.6" "10.1.0.27"
93-
"10.2.0.19"
97+
"10.2.0.19" "10.5.0.18"
9498
)
9599

96100
if(NOT ARG_VERSION IN_LIST downloadable_versions)
@@ -100,6 +104,28 @@ function(download_tensorrt)
100104

101105
set(TRT_VERSION "${ARG_VERSION}")
102106

107+
# Handle TensorRT 8 versions. These are publicly accessible download links.
108+
if(ARG_VERSION VERSION_LESS 9.0.0 AND ARG_VERSION VERSION_GREATER 8.0.0)
109+
string(REGEX MATCH "[0-9]+\\.[0-9]+\\.[0-9]+" trt_short_version ${ARG_VERSION})
110+
set(CUDA_VERSION "12.0")
111+
set(OS "linux")
112+
EXECUTE_PROCESS(COMMAND uname -m
113+
COMMAND tr -d '\n'
114+
OUTPUT_VARIABLE ARCH)
115+
if(ARCH STREQUAL "arm64")
116+
set(ARCH "aarch64")
117+
set(OS "Ubuntu-20.04")
118+
elseif(ARCH STREQUAL "amd64")
119+
set(ARCH "x86_64")
120+
set(OS "Linux")
121+
elseif(ARCH STREQUAL "aarch64")
122+
set(OS "Ubuntu-20.04")
123+
elseif(NOT (ARCH STREQUAL "x86_64"))
124+
message(FATAL_ERROR "Direct download not available for architecture: ${ARCH}")
125+
endif()
126+
set(_url "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/secure/${trt_short_version}/tars/TensorRT-${TRT_VERSION}.${OS}.${ARCH}-gnu.cuda-${CUDA_VERSION}.tar.gz")
127+
endif()
128+
103129
# Handle TensorRT 9 versions. These are publicly accessible download links.
104130
if(ARG_VERSION VERSION_LESS 10.0.0 AND ARG_VERSION VERSION_GREATER 9.0.0)
105131
string(REGEX MATCH "[0-9]+\\.[0-9]+\\.[0-9]+" trt_short_version ${ARG_VERSION})
@@ -137,19 +163,23 @@ function(download_tensorrt)
137163
set(_url "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.2.0/tars/TensorRT-10.2.0.19.Linux.x86_64-gnu.cuda-12.5.tar.gz")
138164
endif()
139165

166+
if(ARG_VERSION VERSION_EQUAL 10.5.0.18)
167+
set(_url "https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.5.0/tars/TensorRT-10.5.0.18.Linux.x86_64-gnu.cuda-12.6.tar.gz")
168+
endif()
169+
140170
if(NOT _url)
141171
message(FATAL_ERROR "Could not determine TensorRT download URL")
142172
endif()
143173

144174
message(STATUS "TensorRT Download URL: ${_url}")
145175

146176
CPMAddPackage(
147-
NAME TensorRT9
177+
NAME TensorRT
148178
VERSION "${TRT_VERSION}"
149179
URL ${_url}
150180
DOWNLOAD_ONLY
151181
)
152-
set("${ARG_OUT_VAR}" "${TensorRT9_SOURCE_DIR}" PARENT_SCOPE)
182+
set("${ARG_OUT_VAR}" "${TensorRT_SOURCE_DIR}" PARENT_SCOPE)
153183
endfunction()
154184

155185
#-------------------------------------------------------------------------------------

mlir-tensorrt/compiler/lib/Conversion/StablehloToTensorRT/StablehloToTensorRT.cpp

+26-12
Original file line numberDiff line numberDiff line change
@@ -2818,10 +2818,15 @@ struct PadConverter : public ConvertHloOpToTensorRTPattern<stablehlo::PadOp> {
28182818
auto padLowHighSum = trtRewriter.checkAndCreate<tensorrt::ElementWiseOp>(
28192819
loc, targetTrtMajorVersion, shapeTensorType, padLowConst, padHighConst,
28202820
tensorrt::ElementWiseOperation::kSUM);
2821+
if (!padLowHighSum)
2822+
return failure();
28212823
Value size = padLowHighSum.getResult();
2822-
size = trtRewriter.checkAndCreate<tensorrt::ElementWiseOp>(
2824+
auto sumWithResult = trtRewriter.checkAndCreate<tensorrt::ElementWiseOp>(
28232825
loc, targetTrtMajorVersion, shapeTensorType, size, shape.getResult(),
28242826
tensorrt::ElementWiseOperation::kSUM);
2827+
if (!sumWithResult)
2828+
return failure();
2829+
size = sumWithResult.getResult();
28252830

28262831
SmallVector<int32_t> stride(inputType.getRank(), 1);
28272832
return trtRewriter.checkAndReplaceOpWithNewOp<tensorrt::SliceOp>(
@@ -3858,7 +3863,7 @@ struct ConvertScatterToTensorRTScatterElements
38583863
if (!constOneTuple)
38593864
return failure();
38603865

3861-
Value newIndices = trtRewriter.checkAndCreate<tensorrt::LinspaceOp>(
3866+
auto newIndices = trtRewriter.checkAndCreate<tensorrt::LinspaceOp>(
38623867
op->getLoc(), targetTrtMajorVersion,
38633868
newUpdateType.clone(rewriter.getI32Type()), Value(), startIndex,
38643869
constOneTuple, FloatAttr(), FloatAttr());
@@ -3884,7 +3889,7 @@ struct ConvertScatterToTensorRTScatterElements
38843889
auto newOp = trtRewriter.checkAndCreate<tensorrt::ScatterElementsOp>(
38853890
op->getLoc(), targetTrtMajorVersion,
38863891
/*data*/ convertToI32(adaptor.getInputs().front()),
3887-
/*indices*/ newIndices,
3892+
/*indices*/ newIndices.getResult(),
38883893
/*updates*/ convertToI32(newUpdates),
38893894
/*axis*/ rewriter.getI64IntegerAttr(axis));
38903895
if (!newOp)
@@ -3894,7 +3899,8 @@ struct ConvertScatterToTensorRTScatterElements
38943899
auto newOp = trtRewriter.checkAndCreate<tensorrt::ScatterElementsOp>(
38953900
op->getLoc(), targetTrtMajorVersion,
38963901
/*data*/ adaptor.getInputs().front(),
3897-
/*indices*/ newIndices, /*updates*/ newUpdates.getResult(),
3902+
/*indices*/ newIndices.getResult(),
3903+
/*updates*/ newUpdates.getResult(),
38983904
/*axis*/ rewriter.getI64IntegerAttr(axis));
38993905
if (!newOp)
39003906
return failure();
@@ -4327,24 +4333,32 @@ struct DynamicUpdateSliceToConcatConverter
43274333
// start and shape to be the values appropriate for !hasNonZeroUpdateStart
43284334
// (static case). We will update them in the condition block.
43294335
// Calculate the slice start = update offset + update size.
4330-
TypedValue<RankedTensorType> concatDimOffset =
4331-
trtRewriter.checkAndCreate<tensorrt::ElementWiseOp>(
4332-
loc, targetTrtMajorVersion, updateStartOffset,
4333-
tensorrt::createConstShapeTensor(
4334-
rewriter, loc,
4335-
{static_cast<int32_t>(updateType.getDimSize(*concatAxis))}),
4336-
tensorrt::ElementWiseOperation::kSUM);
4336+
auto sliceStart = trtRewriter.checkAndCreate<tensorrt::ElementWiseOp>(
4337+
loc, targetTrtMajorVersion, updateStartOffset,
4338+
tensorrt::createConstShapeTensor(
4339+
rewriter, loc,
4340+
{static_cast<int32_t>(updateType.getDimSize(*concatAxis))}),
4341+
tensorrt::ElementWiseOperation::kSUM);
4342+
if (!sliceStart)
4343+
return failure();
4344+
TypedValue<RankedTensorType> concatDimOffset = sliceStart.getResult();
4345+
43374346
TypedValue<RankedTensorType> endOffset = tensorrt::scatterShapeTensor(
43384347
rewriter, loc, SmallVector<int64_t>(updateType.getRank(), 0),
43394348
*concatAxis, concatDimOffset);
43404349
// Calculate the slice size = result shape - update offset.
4341-
TypedValue<RankedTensorType> finalPartDimSize =
4350+
auto finalPartDimSizeOp =
43424351
trtRewriter.checkAndCreate<tensorrt::ElementWiseOp>(
43434352
loc, targetTrtMajorVersion,
43444353
tensorrt::createConstShapeTensor(
43454354
rewriter, loc,
43464355
{static_cast<int32_t>(resultType.getDimSize(*concatAxis))}),
43474356
concatDimOffset, tensorrt::ElementWiseOperation::kSUB);
4357+
if (!finalPartDimSizeOp)
4358+
return failure();
4359+
TypedValue<RankedTensorType> finalPartDimSize =
4360+
finalPartDimSizeOp.getResult();
4361+
43484362
TypedValue<RankedTensorType> endShape = tensorrt::scatterShapeTensor(
43494363
rewriter, loc, resultType.getShape(), *concatAxis, finalPartDimSize);
43504364

mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-control-flow.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-tensorrt-opt -split-input-file %s --convert-stablehlo-to-tensorrt=convert-loops | FileCheck %s
1+
// RUN: mlir-tensorrt-opt -split-input-file %s --convert-stablehlo-to-tensorrt="convert-loops=true trt-major-version=10" | FileCheck %s
22

33
func.func @while() -> tensor<i32> {
44
%arg0 = stablehlo.constant dense<0> : tensor<i32>

mlir-tensorrt/test/Conversion/StablehloToTensorRT/stablehlo-to-tensorrt-invalid.mlir

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-tensorrt-opt -split-input-file %s --convert-stablehlo-to-tensorrt -verify-diagnostics | FileCheck %s
1+
// RUN: mlir-tensorrt-opt -split-input-file %s --convert-stablehlo-to-tensorrt="trt-major-version=10" -verify-diagnostics | FileCheck %s
22

33
func.func @stablehlo_all_reduce_region(%arg0 : tensor<f32>) -> tensor<f32> {
44
%0 = "stablehlo.all_reduce"(%arg0) ({

0 commit comments

Comments
 (0)