Skip to content

Commit a4dc950

Browse files
committed
Add AttentionOp
Fix Attention addLayer, make cmake to work with TRT 10.14
1 parent 326b496 commit a4dc950

13 files changed

Lines changed: 399 additions & 3 deletions

File tree

mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ macro(configure_tensorrt_python_plugin_header)
5757
find_file(
5858
trt_python_plugin_header
5959
NAMES NvInferPythonPlugin.h plugin.h
60-
HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
61-
PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
60+
HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl ${ARG_INSTALL_DIR}/include/impl
61+
PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl ${ARG_INSTALL_DIR}/include/impl
6262
REQUIRED
6363
NO_CMAKE_PATH NO_DEFAULT_PATH
6464
NO_CACHE

mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
136136

137137
if(ARG_VERSION VERSION_EQUAL "10.13.0.35")
138138
set(ARG_VERSION "10.13.0.35")
139+
140+
if(ARG_VERSION VERSION_EQUAL "10.14")
141+
set(ARG_VERSION "10.14.1.48")
139142
endif()
140143

141144
set(downloadable_versions
@@ -156,6 +159,7 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
156159
"10.9.0.34"
157160
"10.12.0.36"
158161
"10.13.0.35"
162+
"10.14.1.48"
159163
)
160164

161165
if(NOT ARG_VERSION IN_LIST downloadable_versions)

mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/dialects/test_tensorrt.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def test_attributes():
4949
tensorrt.TripLimitAttr.get("kWHILE"),
5050
tensorrt.FillOperationAttr.get("kRANDOM_UNIFORM"),
5151
tensorrt.ScatterModeAttr.get("kELEMENT"),
52+
tensorrt.AttentionNormalizationOpAttr.get("kSOFTMAX"),
53+
tensorrt.DataTypeAttr.get("kFLOAT"),
5254
]:
5355
print(attr)
5456

@@ -74,3 +76,5 @@ def test_attributes():
7476
# CHECK-NEXT: #tensorrt.trip_limit<kWHILE>
7577
# CHECK-NEXT: #tensorrt.fill_operation<kRANDOM_UNIFORM>
7678
# CHECK-NEXT: #tensorrt.scatter_mode<kELEMENT>
79+
# CHECK-NEXT: #tensorrt.attention_normalization_op<kSOFTMAX>
80+
# CHECK-NEXT: #tensorrt.data_type<kFLOAT>

mlir-tensorrt/compiler/tools/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,5 @@ add_subdirectory(mlir-tensorrt-shlib)
4040
add_subdirectory(mlir-tensorrt-opt)
4141
add_subdirectory(mlir-tensorrt-compiler)
4242
add_subdirectory(mlir-tensorrt-translate)
43-
add_subdirectory(mlir-tensorrt-lsp-server)
43+
# add_subdirectory(mlir-tensorrt-lsp-server)
4444
add_subdirectory(mlir-tensorrt-runner)

mlir-tensorrt/integrations/python/bindings/Compiler/DialectTensorRT.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,6 @@ PYBIND11_MODULE(_tensorrt, m) {
7777
ADD_PYTHON_ATTRIBUTE_ADAPTOR(TripLimit)
7878
ADD_PYTHON_ATTRIBUTE_ADAPTOR(FillOperation)
7979
ADD_PYTHON_ATTRIBUTE_ADAPTOR(ScatterMode)
80+
ADD_PYTHON_ATTRIBUTE_ADAPTOR(AttentionNormalizationOp)
81+
ADD_PYTHON_ATTRIBUTE_ADAPTOR(DataType)
8082
}

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect-c/TensorRTAttributes.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,22 @@ DECLARE_ATTR_GETTER_FROM_STRING(ScatterMode)
188188
DECLARE_IS_ATTR(ScatterMode)
189189
DECLARE_STRING_GETTER_FROM_ATTR(ScatterMode)
190190

191+
//===----------------------------------------------------------------------===//
192+
// AttentionNormalizationOp
193+
//===----------------------------------------------------------------------===//
194+
195+
DECLARE_ATTR_GETTER_FROM_STRING(AttentionNormalizationOp)
196+
DECLARE_IS_ATTR(AttentionNormalizationOp)
197+
DECLARE_STRING_GETTER_FROM_ATTR(AttentionNormalizationOp)
198+
199+
//===----------------------------------------------------------------------===//
200+
// DataType
201+
//===----------------------------------------------------------------------===//
202+
203+
DECLARE_ATTR_GETTER_FROM_STRING(DataType)
204+
DECLARE_IS_ATTR(DataType)
205+
DECLARE_STRING_GETTER_FROM_ATTR(DataType)
206+
191207
#ifdef __cplusplus
192208
}
193209
#endif

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTEnums.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,4 +378,42 @@ def TensorRT_ScatterMode : TensorRT_I32EnumAttr<
378378
def TensorRT_ScatterModeAttr : TensorRT_EnumAttr<TensorRT_ScatterMode, "scatter_mode">{
379379
}
380380

381+
// Enum naming the normalization step of `tensorrt.attention`: kNONE (0) or
// kSOFTMAX (1). TensorRT_AttentionOp's `normalization_operation` attribute
// uses kSOFTMAX as its default value.
def TensorRT_AttentionNormalizationOp : TensorRT_I32EnumAttr<
382+
"AttentionNormalizationOp", "",
383+
[
384+
I32EnumAttrCase<"kNONE", 0>,
385+
I32EnumAttrCase<"kSOFTMAX", 1>
386+
]>
387+
{
388+
let cppNamespace = "::mlir::tensorrt";
389+
let genSpecializedAttr = 0;
390+
}
391+
392+
// Attribute wrapping TensorRT_AttentionNormalizationOp under the mnemonic
// `attention_normalization_op`, e.g.
// `#tensorrt.attention_normalization_op<kSOFTMAX>`.
def TensorRT_AttentionNormalizationOpAttr : TensorRT_EnumAttr<TensorRT_AttentionNormalizationOp, "attention_normalization_op">{
393+
}
394+
395+
// Enum of data-type tags, kFLOAT (0) through kE8M0 (11). It is used by
// `tensorrt.attention` for `normalization_quantize_to_type`, where the value
// is translated with convertDataTypeToNvInferEnum before reaching TensorRT.
// NOTE(review): the numeric values presumably mirror nvinfer1::DataType —
// confirm against the TensorRT headers when adding cases.
def TensorRT_DataType : TensorRT_I32EnumAttr<
396+
"DataType", "",
397+
[
398+
I32EnumAttrCase<"kFLOAT", 0>,
399+
I32EnumAttrCase<"kHALF", 1>,
400+
I32EnumAttrCase<"kINT8", 2>,
401+
I32EnumAttrCase<"kINT32", 3>,
402+
I32EnumAttrCase<"kBOOL", 4>,
403+
I32EnumAttrCase<"kUINT8", 5>,
404+
I32EnumAttrCase<"kFP8", 6>,
405+
I32EnumAttrCase<"kBF16", 7>,
406+
I32EnumAttrCase<"kINT64", 8>,
407+
I32EnumAttrCase<"kINT4", 9>,
408+
I32EnumAttrCase<"kFP4", 10>,
409+
I32EnumAttrCase<"kE8M0", 11>
410+
]>
411+
{
412+
let cppNamespace = "::mlir::tensorrt";
413+
let genSpecializedAttr = 0;
414+
}
415+
416+
// Attribute wrapping TensorRT_DataType under the mnemonic `data_type`,
// e.g. `#tensorrt.data_type<kFLOAT>`.
def TensorRT_DataTypeAttr : TensorRT_EnumAttr<TensorRT_DataType, "data_type">{
417+
}
418+
381419
#endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTENUMS

mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3507,6 +3507,171 @@ def TensorRT_DequantizeOp : TensorRT_Op<"dequantize",
35073507
}];
35083508
}
35093509

3510+
//===----------------------------------------------------------------------===//
3511+
// AttentionOp
3512+
//===----------------------------------------------------------------------===//
3513+
3514+
// ODS definition of `tensorrt.attention`. The trtLayerAdd snippet at the end
// of this record builds the layer with `$net->addAttention(...)`, which yields
// an `nvinfer1::IAttention *`.
def TensorRT_AttentionOp : TensorRT_Op<"attention",
3515+
[Pure, AttrSizedOperandSegments, TensorRTPartiallyInferTensorResultTypes,
3516+
AllElementTypesMatch<["query", "key", "value"]>,
3517+
AllRanksMatch<["query", "key", "value"]>]>{
3518+
let summary = "TensorRT attention (IAttention) operation";
3519+
let description = [{
3520+
The `tensorrt.attention` operation implements a fused attention mechanism
3521+
that consumes query, key, and value tensors. The operation implicitly includes
3522+
two matrix multiplication layers (BMM1 and BMM2) and a normalization operation
3523+
(typically softmax).
3524+
3525+
By default, TensorRT will try to use a single fused kernel for better efficiency.
3526+
The operation can optionally be decomposed into multiple kernels if no fused
3527+
kernel is available by setting `decomposable` to true.
3528+
3529+
#### Architecture:
3530+
3531+
```
3532+
Query Key Value Mask (optional) NormalizationQuantizeScale (optional)
3533+
| | | | |
3534+
| Transpose | | |
3535+
| | | | |
3536+
----BMM1---- | | |
3537+
| | | |
3538+
*--------------------------- |
3539+
| | |
3540+
Normalization | |
3541+
| | |
3542+
*------------------------------------------------
3543+
| |
3544+
-------BMM2------
3545+
|
3546+
Output
3547+
```
3548+
3549+
#### Inputs:
3550+
3551+
- Query: tensor of type f32, f16, or bf16 with shape
3552+
[batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]
3553+
- Key: tensor of type f32, f16, or bf16 with shape
3554+
[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
3555+
- Value: tensor of type f32, f16, or bf16 with shape
3556+
[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
3557+
- Mask (optional): tensor of type i1 or same type as BMM1 output with shape
3558+
[batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue]
3559+
where batchSize and numHeadsQuery are broadcastable. For i1 mask, true
3560+
indicates the position is allowed to attend. For other types, mask values
3561+
are added to BMM1 output.
3562+
- NormalizationQuantizeScale (optional): tensor of type f32, f16, or bf16
3563+
with rank 0 (scalar) or 1 (1D tensor), used for quantizing the normalization output.
3564+
Required when normalization_quantize_to_type is specified.
3565+
3566+
#### Attributes:
3567+
3568+
- normalization_operation: The normalization operation to use (default: kSOFTMAX)
3569+
- causal: Whether to use causal masking (default: false). Cannot be used with mask input.
3570+
- decomposable: Whether the operation can be decomposed (default: false)
3571+
- normalization_quantize_to_type: Optional output type for quantized normalization.
3572+
When specified, must be one of kFP8 or kINT8. Requires normalization_quantize_scale input to be provided.
3573+
3574+
#### Constraints:
3575+
3576+
- All query, key, and value tensors must be rank 4 with shape [batchSize, numHeads, sequenceLength, dimHead]
3577+
- Query, key, and value must have the same element type (f32, f16, or bf16)
3578+
- If normalization_quantize_to_type is specified:
3579+
* It must be kFP8 or kINT8
3580+
* normalization_quantize_scale input must be provided
3581+
- If normalization_quantize_scale is provided:
3582+
* normalization_quantize_to_type must be specified
3583+
* Element type must be f32, f16, or bf16
3584+
* Rank must be 0 (scalar) or 1 (1D tensor)
3585+
- Cannot use both mask input and causal=true simultaneously
3586+
3587+
#### Examples:
3588+
3589+
Basic attention:
3590+
```mlir
3591+
%output = tensorrt.attention ins(%query, %key, %value :
3592+
tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
3593+
-> tensor<2x8x128x64xf16>
3594+
```
3595+
3596+
Causal attention:
3597+
```mlir
3598+
%output_causal = tensorrt.attention {causal = true} ins(%query, %key, %value :
3599+
tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
3600+
-> tensor<2x8x128x64xf16>
3601+
```
3602+
3603+
Attention with quantization:
3604+
```mlir
3605+
%scale = tensorrt.constant dense<1.0> : tensor<f32>
3606+
%output_quant = tensorrt.attention {
3607+
normalization_quantize_to_type = #tensorrt.data_type<kFP8>
3608+
} ins(%query, %key, %value,
3609+
normalization_quantize_scale = %scale :
3610+
tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>,
3611+
tensor<2x8x128x64xf16>, tensor<f32>)
3612+
-> tensor<2x8x128x64xf16>
3613+
```
3614+
}];
3615+
3616+
// Operands: query/key/value are required; mask and
// normalization_quantize_scale are Optional, which is why the op carries
// AttrSizedOperandSegments. The four attributes follow the operands.
let arguments = (ins
3617+
TensorRT_RankedTensorOf<[F16, BF16, F32]>:$query,
3618+
TensorRT_RankedTensorOf<[F16, BF16, F32]>:$key,
3619+
TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value,
3620+
Optional<TensorRT_Tensor>:$mask,
3621+
Optional<TensorRT_RankedTensorOf<[F16, BF16, F32]>>:$normalization_quantize_scale,
3622+
DefaultValuedAttr<TensorRT_AttentionNormalizationOpAttr, "tensorrt::AttentionNormalizationOp::kSOFTMAX">:$normalization_operation,
3623+
DefaultValuedAttr<BoolAttr, "false">:$causal,
3624+
DefaultValuedAttr<BoolAttr, "false">:$decomposable,
3625+
OptionalAttr<TensorRT_DataTypeAttr>:$normalization_quantize_to_type
3626+
);
3627+
3628+
let results = (outs TensorRT_RankedTensorOf<[F16, BF16, F32]>:$result);
3629+
3630+
// NOTE(review): both optional operand groups (and both optional type groups)
// begin with a `,` literal. Please confirm that the scale-only form (no mask),
// as used in the "Attention with quantization" example, parses and
// round-trips; the first group may claim the comma and then require `mask`.
let assemblyFormat = [{
3631+
attr-dict `ins` `(` $query `,` $key `,` $value
3632+
(`,` `mask` `=` $mask^)?
3633+
(`,` `normalization_quantize_scale` `=` $normalization_quantize_scale^)?
3634+
`:` type($query) `,` type($key) `,` type($value)
3635+
(`,` type($mask)^)?
3636+
(`,` type($normalization_quantize_scale)^)?
3637+
`)` `->` type($result)
3638+
}];
3639+
3640+
// A hand-written verifier is requested for this op; its definition is not in
// this file. Presumably it enforces the constraints listed in the
// description — TODO confirm.
let hasVerifier = 1;
3641+
3642+
let extraClassDeclaration = [{
3643+
/// Returns true if created op is valid for TensorRT major version.
3644+
bool isValidForTensorRTVersion(int64_t trtMajorVersion);
3645+
}] # baseClassDeclaration;
3646+
3647+
// Network-construction snippet: create the IAttention layer, then apply the
// optional mask, the decomposable flag and the optional quantization
// scale/type, and record output 0 as the op result. setMetadata is only
// compiled in for TensorRT >= 10.15.0.
// NOTE(review): the results of the set* calls are not inspected — check the
// IAttention API for whether they can report failure.
let trtLayerAdd = [{
3648+
nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, *$normalization_operation, $causal);
3649+
if (!layer)
3650+
return failure();
3651+
3652+
if ($mask)
3653+
layer->setMask(*$mask);
3654+
3655+
layer->setDecomposable($decomposable);
3656+
3657+
if ($normalization_quantize_scale) {
3658+
layer->setNormalizationQuantizeScale(*$normalization_quantize_scale);
3659+
}
3660+
3661+
if ($normalization_quantize_to_type) {
3662+
auto convertedDataType = ::mlir::tensorrt::convertDataTypeToNvInferEnum(*$normalization_quantize_to_type);
3663+
if (!convertedDataType)
3664+
return emitError($op->getLoc()) << "failed to convert DataType to nvinfer enum";
3665+
layer->setNormalizationQuantizeToType(*convertedDataType);
3666+
}
3667+
3668+
$results.push_back(layer->getOutput(0));
3669+
#if MLIR_TRT_COMPILE_TIME_TENSORRT_VERSION_GTE(10, 15, 0)
3670+
layer->setMetadata($op);
3671+
#endif
3672+
}];
3673+
}
3674+
35103675
//===----------------------------------------------------------------------===//
35113676
// TensorRT Dialect Extension Operations
35123677
//

mlir-tensorrt/tensorrt/lib/CAPI/TensorRTAttributes.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,11 @@ DEFINE_STRING_GETTER_FROM_ATTR(FillOperation)
121121
DEFINE_ATTR_GETTER_FROM_STRING(ScatterMode)
122122
DEFINE_IS_ATTR(ScatterMode)
123123
DEFINE_STRING_GETTER_FROM_ATTR(ScatterMode)
124+
125+
DEFINE_ATTR_GETTER_FROM_STRING(AttentionNormalizationOp)
126+
DEFINE_IS_ATTR(AttentionNormalizationOp)
127+
DEFINE_STRING_GETTER_FROM_ATTR(AttentionNormalizationOp)
128+
129+
DEFINE_ATTR_GETTER_FROM_STRING(DataType)
130+
DEFINE_IS_ATTR(DataType)
131+
DEFINE_STRING_GETTER_FROM_ATTR(DataType)

mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRTVersionCompatibility.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,3 +915,16 @@ bool tensorrt::ScatterElementsOp::isValidForTensorRTVersion(
915915
return isValidForTensorRTVersionScatterOpImpl(
916916
trtMajorVersion, dataElementType, indicesElementType);
917917
}
918+
919+
//===----------------------------------------------------------------------===//
920+
// AttentionOp
921+
//===----------------------------------------------------------------------===//
922+
923+
// Reports whether `tensorrt.attention` may be used with the given TensorRT
// major version: false for major versions below 10, true otherwise.
bool tensorrt::AttentionOp::isValidForTensorRTVersion(
924+
int64_t trtMajorVersion) {
925+
// IAttention layer is only supported in TensorRT >= 10.14.0.
// NOTE(review): only the major version is available here, so 10.x releases
// older than 10.14 are not rejected by this check — confirm that is acceptable.
926+
if (trtMajorVersion < 10)
927+
return false;
928+
929+
return true;
930+
}

0 commit comments

Comments
 (0)