For tensorrt.host_tensor:
module @"ins_x_y_outs_%t19_0" {
  func.func @main(%arg0: tensor<?x4xf32> {tensorrt.shape_profile = #tensorrt.shape_profile<min = [2, 4], opt = [4, 4], max = [6, 4]>}, %arg1: tensor<i32> {tensorrt.host_tensor, tensorrt.value_bounds = #tensorrt.shape_profile<min = [1], opt = [2], max = [3]>}) -> tensor<?x?xf32> {
    %0 = tensorrt.element_wise <kSUM>(%arg0, %arg0 : tensor<?x4xf32>, tensor<?x4xf32>) -> tensor<?x4xf32>
    %1 = tensorrt.shape %0 : tensor<?x4xf32> -> tensor<2xi32>
    %2 = tensorrt.slice %1[0][1][1] : tensor<2xi32> to tensor<1xi32>
    %3 = tensorrt.collapse_rank %2 : tensor<1xi32> to tensor<i32>
    %cst_i32 = tensorrt.constant dense<1> : tensor<i32>
    %4 = tensorrt.element_wise <kPROD>(%3, %cst_i32 : tensor<i32>, tensor<i32>) -> tensor<i32>
    %5 = tensorrt.slice %1[1][1][1] : tensor<2xi32> to tensor<1xi32>
    %6 = tensorrt.collapse_rank %5 : tensor<1xi32> to tensor<i32>
    %7 = tensorrt.element_wise <kPROD>(%4, %6 : tensor<i32>, tensor<i32>) -> tensor<i32>
    %cst_i32_0 = tensorrt.constant dense<1> : tensor<i32>
    %8 = tensorrt.element_wise <kPROD>(%arg1, %cst_i32_0 : tensor<i32>, tensor<i32>) -> tensor<i32>
    %9 = tensorrt.element_wise <kFLOOR_DIV>(%7, %8 : tensor<i32>, tensor<i32>) -> tensor<i32>
    %cst_i32_1 = tensorrt.constant dense<1> : tensor<1xi32>
    %10 = tensorrt.reshape %9 shape(%cst_i32_1: tensor<1xi32>) : tensor<i32> to tensor<?xi32>
    %cst_i32_2 = tensorrt.constant dense<1> : tensor<1xi32>
    %11 = tensorrt.reshape %arg1 shape(%cst_i32_2: tensor<1xi32>) : tensor<i32> to tensor<?xi32>
    %12 = tensorrt.concatenation {axis = 0 : i32} ins(%10, %11 : tensor<?xi32>, tensor<?xi32>) -> tensor<2xi32>
    %13 = tensorrt.reshape %0 shape(%12: tensor<2xi32>) : tensor<?x4xf32> to tensor<?x?xf32>
    return %13 : tensor<?x?xf32>
  }
}
At runtime, I get this error:
Input argument 1 validation failed against corresponding function signature arg 1. Reason: InvalidArgument: function expects a memref type with address space device but receieved host
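even though the 2nd argument has the tensorrt.host_tensor attribute. The compiled TensorRT engine also recognizes the 2nd argument as a shape input, so the host placement appears to be lost only in the MLIR-TRT executable's function signature.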
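For readers who want to check what the module is meant to compute, here is a minimal pure-Python sketch of its shape arithmetic. It is only my reading of the IR, not anything produced by MLIR-TensorRT, and `reshape_target_shape` is a hypothetical helper name. The second argument is used purely as a shape value, which is why it is expected to live on the host.

```python
def reshape_target_shape(x_shape, k):
    """Mirror the shape arithmetic of the module (illustrative only).

    x_shape: runtime shape of %arg0, e.g. (4, 4)
    k:       scalar value of %arg1 (the host tensor)
    """
    rows, cols = x_shape
    numel = rows * cols      # %7: product of the two dimensions of %0
    lead = numel // k        # %9: kFLOOR_DIV of %7 by %arg1
    return (lead, k)         # %12: concatenation of %9 and %arg1


# Values taken from the min/opt/max entries of the attributes in the module.
for x_shape, k in [((2, 4), 1), ((4, 4), 2), ((6, 4), 3)]:
    print(x_shape, k, "->", reshape_target_shape(x_shape, k))
```

With the opt values (a 4x4 input and a second argument of 2), the sketch gives a target shape of (8, 2).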
For tensorrt.shape_profile:
import nvtripy as tp
import mlir_tensorrt.runtime.api as runtime


def func(x):
    x = x + x
    return x


# Dimension 0 is dynamic with (min, opt, max) = (2, 4, 6); dimension 1 is fixed at 4.
compiled_func = tp.compile(func, args=[tp.InputInfo(shape=((2, 4, 6), 4), dtype=tp.float32)])

# Inspect the signature of the MLIR-TRT executable.
sig = compiled_func._executable_signature
arg = sig.get_arg(0)
memref = runtime.MemRefType(arg)
print("Shape: ", memref.shape)

bound = sig.get_arg_bound(0)
print(f"Bound: {bound.min()}, {bound.max()}")  # shape bounds are empty
Note: the compiled TensorRT engine has the correct shape profiles, but the MLIR-TRT executable loses them.
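To make the expected result concrete, the sketch below expands the shape passed to tp.InputInfo into the min/opt/max bounds I would expect the executable signature to report. It assumes a tuple dimension means (min, opt, max) and a plain integer means a fixed size. `expand_shape_profile` is a hypothetical helper written for this report, not part of nvtripy or mlir_tensorrt.

```python
def expand_shape_profile(shape_spec):
    """Expand a shape spec into min/opt/max lists (hypothetical helper).

    Assumption: each entry is either an int (static dimension) or a
    (min, opt, max) tuple (dynamic dimension).
    """
    mins, opts, maxs = [], [], []
    for dim in shape_spec:
        if isinstance(dim, int):
            lo = opt = hi = dim   # static dimension: all three bounds agree
        else:
            lo, opt, hi = dim     # dynamic dimension: unpack the triple
        mins.append(lo)
        opts.append(opt)
        maxs.append(hi)
    return {"min": mins, "opt": opts, "max": maxs}


expected = expand_shape_profile(((2, 4, 6), 4))
print(expected)
```

Under that assumption, the expected bounds are min = [2, 4], opt = [4, 4], max = [6, 4], whereas the signature currently returns empty bounds.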