@@ -3507,6 +3507,171 @@ def TensorRT_DequantizeOp : TensorRT_Op<"dequantize",
35073507 }];
35083508}
35093509
3510+ //===----------------------------------------------------------------------===//
3511+ // AttentionOp
3512+ //===----------------------------------------------------------------------===//
3513+
// Fused attention op that maps onto TensorRT's `nvinfer1::IAttention`.
//
// Traits:
//  - `AttrSizedOperandSegments`: the op has two optional operands (`mask` and
//    `normalization_quantize_scale`), so operand segment sizes are recorded.
//  - `AllElementTypesMatch` / `AllRanksMatch`: query, key and value must agree
//    in element type and rank.
// Constraints listed in the description that are not expressed by these
// traits or by the argument types below are left to the C++ verifier
// (`hasVerifier = 1`), which is defined outside this file.
def TensorRT_AttentionOp : TensorRT_Op<"attention",
    [Pure, AttrSizedOperandSegments, TensorRTPartiallyInferTensorResultTypes,
     AllElementTypesMatch<["query", "key", "value"]>,
     AllRanksMatch<["query", "key", "value"]>]>{
  let summary = "TensorRT attention (IAttention) operation";
  let description = [{
    The `tensorrt.attention` operation implements a fused attention mechanism
    that consumes query, key, and value tensors. The operation implicitly includes
    two matrix multiplication layers (BMM1 and BMM2) and a normalization operation
    (typically softmax).

    By default, TensorRT will try to use a single fused kernel for better efficiency.
    Setting `decomposable` to true additionally allows the operation to be
    decomposed into multiple kernels when no fused kernel is available.

    #### Architecture:

    ```
    Query       Key     Value  Mask (optional)  NormalizationQuantizeScale (optional)
      |          |        |           |                           |
      |      Transpose    |           |                           |
      |          |        |           |                           |
      ----BMM1----        |           |                           |
           |              |           |                           |
           *---------------------------                           |
           |              |                                       |
     Normalization        |                                       |
           |              |                                       |
           *-------------------------------------------------------
           |              |
          -------BMM2------
                  |
                Output
    ```

    #### Inputs:

    - Query: tensor of type f32, f16, or bf16 with shape
      [batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]
    - Key: tensor of type f32, f16, or bf16 with shape
      [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
    - Value: tensor of type f32, f16, or bf16 with shape
      [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
    - Mask (optional): tensor of type i1 or same type as BMM1 output with shape
      [batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue]
      where batchSize and numHeadsQuery are broadcastable. For i1 mask, true
      indicates the position is allowed to attend. For other types, mask values
      are added to BMM1 output.
    - NormalizationQuantizeScale (optional): tensor of type f32, f16, or bf16
      with rank 0 (scalar) or 1 (1D tensor), used for quantizing the normalization output.
      Required when normalization_quantize_to_type is specified.

    #### Attributes:

    - normalization_operation: The normalization operation to use (default: kSOFTMAX)
    - causal: Whether to use causal masking (default: false). Cannot be used with mask input.
    - decomposable: Whether the operation can be decomposed (default: false)
    - normalization_quantize_to_type: Optional output type for quantized normalization.
      When specified, must be one of kFP8 or kINT8. Requires normalization_quantize_scale input to be provided.

    #### Constraints:

    - All query, key, and value tensors must be rank 4 with shape [batchSize, numHeads, sequenceLength, dimHead]
    - Query, key, and value must have the same element type (f32, f16, or bf16)
    - If normalization_quantize_to_type is specified:
      * It must be kFP8 or kINT8
      * normalization_quantize_scale input must be provided
    - If normalization_quantize_scale is provided:
      * normalization_quantize_to_type must be specified
      * Element type must be f32, f16, or bf16
      * Rank must be 0 (scalar) or 1 (1D tensor)
    - Cannot use both mask input and causal=true simultaneously

    #### Examples:

    Basic attention:
    ```mlir
    %output = tensorrt.attention ins(%query, %key, %value :
      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
      -> tensor<2x8x128x64xf16>
    ```

    Causal attention:
    ```mlir
    %output_causal = tensorrt.attention {causal = true} ins(%query, %key, %value :
      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
      -> tensor<2x8x128x64xf16>
    ```

    Attention with quantization:
    ```mlir
    %scale = tensorrt.constant dense<1.0> : tensor<f32>
    %output_quant = tensorrt.attention {
      normalization_quantize_to_type = #tensorrt.data_type<kFP8>
    } ins(%query, %key, %value,
      normalization_quantize_scale = %scale :
      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>,
      tensor<2x8x128x64xf16>, tensor<f32>)
      -> tensor<2x8x128x64xf16>
    ```
  }];

  // Operands first (three required, two optional), then attributes. The
  // optional operands are why `AttrSizedOperandSegments` is attached above.
  let arguments = (ins
    TensorRT_RankedTensorOf<[F16, BF16, F32]>:$query,
    TensorRT_RankedTensorOf<[F16, BF16, F32]>:$key,
    TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value,
    Optional<TensorRT_Tensor>:$mask,
    Optional<TensorRT_RankedTensorOf<[F16, BF16, F32]>>:$normalization_quantize_scale,
    DefaultValuedAttr<TensorRT_AttentionNormalizationOpAttr, "tensorrt::AttentionNormalizationOp::kSOFTMAX">:$normalization_operation,
    DefaultValuedAttr<BoolAttr, "false">:$causal,
    DefaultValuedAttr<BoolAttr, "false">:$decomposable,
    OptionalAttr<TensorRT_DataTypeAttr>:$normalization_quantize_to_type
  );

  let results = (outs TensorRT_RankedTensorOf<[F16, BF16, F32]>:$result);

  // NOTE(review): both optional operand groups (and both optional type groups)
  // begin with the same `,` literal. The declarative format decides whether to
  // enter an optional group from its first element only, so when just
  // `normalization_quantize_scale` is supplied (no mask) the parser appears to
  // take the `mask` group after seeing `,` and then fail on the `mask` keyword.
  // Please confirm with a parse/print round-trip test of the "Attention with
  // quantization" example; fixing it needs a syntax change or a custom
  // directive, so the format is intentionally left untouched here.
  let assemblyFormat = [{
    attr-dict `ins` `(` $query `,` $key `,` $value
      (`,` `mask` `=` $mask^)?
      (`,` `normalization_quantize_scale` `=` $normalization_quantize_scale^)?
      `:` type($query) `,` type($key) `,` type($value)
      (`,` type($mask)^)?
      (`,` type($normalization_quantize_scale)^)?
    `)` `->` type($result)
  }];

  let hasVerifier = 1;

  let extraClassDeclaration = [{
    /// Returns true if created op is valid for TensorRT major version.
    bool isValidForTensorRTVersion(int64_t trtMajorVersion);
  }] # baseClassDeclaration;

  // Network-builder snippet: creates the IAttention layer, applies the optional
  // inputs/flags, and records the layer's first output as the op result.
  let trtLayerAdd = [{
    nvinfer1::IAttention *layer = $net->addAttention(
        *$query, *$key, *$value, *$normalization_operation, $causal);
    if (!layer)
      return failure();

    // The IAttention setters report rejection through their boolean return
    // value. Surface it as an error rather than silently building a layer that
    // does not match the op.
    if ($mask && !layer->setMask(*$mask))
      return emitError($op->getLoc()) << "failed to set attention mask";

    if (!layer->setDecomposable($decomposable))
      return emitError($op->getLoc())
             << "failed to set attention decomposable flag";

    if ($normalization_quantize_scale &&
        !layer->setNormalizationQuantizeScale(*$normalization_quantize_scale))
      return emitError($op->getLoc())
             << "failed to set attention normalization quantize scale";

    if ($normalization_quantize_to_type) {
      auto convertedDataType = ::mlir::tensorrt::convertDataTypeToNvInferEnum(*$normalization_quantize_to_type);
      if (!convertedDataType)
        return emitError($op->getLoc()) << "failed to convert DataType to nvinfer enum";
      if (!layer->setNormalizationQuantizeToType(*convertedDataType))
        return emitError($op->getLoc())
               << "failed to set attention normalization quantize type";
    }

    $results.push_back(layer->getOutput(0));
    // Metadata is only attached when building against TensorRT >= 10.15.0.
    #if MLIR_TRT_COMPILE_TIME_TENSORRT_VERSION_GTE(10, 15, 0)
    layer->setMetadata($op);
    #endif
  }];
}
3674+
35103675//===----------------------------------------------------------------------===//
35113676// TensorRT Dialect Extension Operations
35123677//
0 commit comments