diff --git a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
index 0131f2c71f22..15e37ca44593 100644
--- a/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/ConvertTMTensorToLinalgExt.cpp
@@ -149,33 +149,25 @@ struct AttentionOpConversion
         tensor::EmptyOp::create(rewriter, loc, outputType.getShape(),
                                 outputType.getElementType(), dynSizes);
 
-    // TODO: This is a hack. This should be replaced with a simple getScale()
-    // when support for scaling is plumbed to TMTensor on the torch-mlir side.
-    // Until then, we are using the default value used in scaled dot product
-    // attention by PyTorch (most models use the default value because it makes
-    // the variance of the result of softmax 1 when the mean of Q, K is 0).
-    // We use scale = 1 / sqrt(d), where d is the head dimension.
-    // See https://paperswithcode.com/method/scaled for more details.
-    //
-    // TODO: We are currently assuming that head dimension is dim = -1. Once we
-    // have support for batch dims using more general indexing maps, we should
-    // change this and rely on more general mechanisms.
-    // TODO: We are currently not handling dynamic shape of head dimensions at
-    // all. This is because it messes with dispatch formation. This should be
-    // fixed.
-    ArrayRef<int64_t> queryShape = op.getQueryType().getShape();
-    int64_t headDim = queryShape.back();
-    if (headDim == ShapedType::kDynamic) {
-      return op->emitOpError("NYI: Dynamic head dimension");
-    }
-
-    // Attention only works for FloatType.
+    // Compute scale = 1 / sqrt(headDim), where headDim is the last dimension
+    // of the query tensor. When headDim is static, fold to a constant.
     FloatType targetType = cast<FloatType>(op.getQueryType().getElementType());
-
-    double dk = static_cast<double>(headDim);
-    dk = 1.0 / std::sqrt(dk);
-    Value scale = arith::ConstantOp::create(
-        rewriter, loc, targetType, rewriter.getFloatAttr(targetType, dk));
+    int64_t headDim = op.getQueryType().getShape().back();
+    Value scale;
+    if (headDim != ShapedType::kDynamic) {
+      double dk = 1.0 / std::sqrt(static_cast<double>(headDim));
+      scale = arith::ConstantOp::create(rewriter, loc, targetType,
+                                        rewriter.getFloatAttr(targetType, dk));
+    } else {
+      int64_t queryRank = op.getQueryType().getRank();
+      Value headDimIndex =
+          tensor::DimOp::create(rewriter, loc, query, queryRank - 1);
+      Value headDimInt = arith::IndexCastOp::create(
+          rewriter, loc, rewriter.getI64Type(), headDimIndex);
+      Value headDimFloat =
+          arith::SIToFPOp::create(rewriter, loc, targetType, headDimInt);
+      scale = math::RsqrtOp::create(rewriter, loc, headDimFloat);
+    }
 
     // Add batches to standard attention indexing maps.
     SmallVector<AffineMap> indexingMaps =
diff --git a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
index 59c169a8769b..8903c07d7a7d 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/attention.mlir
@@ -94,3 +94,65 @@ func.func @attention_dyn(%arg0: tensor<?x?x4xf32>, %arg1: tensor<?x?x4xf32>, %ar
 // CHECK: linalg_ext.yield %[[SCORE]]
 // CHECK: } -> tensor<?x?x4xf32>
 // CHECK: return %[[ATTN]] : tensor<?x?x4xf32>
+
+// -----
+func.func @attention_dynamic_head_dim(%arg0: tensor<5x2x3x?xf32>, %arg1: tensor<5x2x3x?xf32>, %arg2: tensor<5x2x3x?xf32>, %arg3: tensor<5x2x3x?xf32>) -> (tensor<5x2x3x?xf32>) {
+  %0 = tm_tensor.attention ins(%arg0, %arg1, %arg2 : tensor<5x2x3x?xf32>, tensor<5x2x3x?xf32>, tensor<5x2x3x?xf32>) outs(%arg3: tensor<5x2x3x?xf32>) -> tensor<5x2x3x?xf32>
+  return %0 : tensor<5x2x3x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: func.func @attention_dynamic_head_dim(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x3x?xf32>, %[[ARG1:.*]]: tensor<5x2x3x?xf32>, %[[ARG2:.*]]: tensor<5x2x3x?xf32>,
+// CHECK-SAME: %[[ARG3:.*]]: tensor<5x2x3x?xf32>) -> tensor<5x2x3x?xf32> {
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[DIM_V:.*]] = tensor.dim %[[ARG2]], %[[C3]]
+// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM_V]]) : tensor<5x2x3x?xf32>
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C3]]
+// CHECK: %[[DIM_INT:.*]] = arith.index_cast %[[DIM]] : index to i64
+// CHECK: %[[DIM_FLOAT:.*]] = arith.sitofp %[[DIM_INT]] : i64 to f32
+// CHECK: %[[SCALE:.*]] = math.rsqrt %[[DIM_FLOAT]] : f32
+// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x?xf32>, tensor<5x2x3x?xf32>, tensor<5x2x3x?xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x?xf32>) {
+// CHECK: ^[[BLOCK:.+]](%[[SCORE:.+]]: f32):
+// CHECK: linalg_ext.yield %[[SCORE]]
+// CHECK: } -> tensor<5x2x3x?xf32>
+// CHECK: return %[[ATTN]] : tensor<5x2x3x?xf32>
+
+// -----
+func.func @attention_fully_dynamic(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>, %arg3: tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>) {
+  %0 = tm_tensor.attention ins(%arg0, %arg1, %arg2 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%arg3: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: func.func @attention_fully_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?x?xf32>, %[[ARG2:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[ARG3:.*]]: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK: %[[DIM_V:.*]] = tensor.dim %[[ARG2]], %[[C3]]
+// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM_V]]) : tensor<?x?x?x?xf32>
+// CHECK: %[[DIM_Q:.*]] = tensor.dim %[[ARG0]], %[[C3]]
+// CHECK: %[[DIM_INT:.*]] = arith.index_cast %[[DIM_Q]] : index to i64
+// CHECK: %[[DIM_FLOAT:.*]] = arith.sitofp %[[DIM_INT]] : i64 to f32
+// CHECK: %[[SCALE:.*]] = math.rsqrt %[[DIM_FLOAT]] : f32
+// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, f32) outs(%[[EMPTY]] : tensor<?x?x?x?xf32>) {
+// CHECK: ^[[BLOCK:.+]](%[[SCORE:.+]]: f32):
+// CHECK: linalg_ext.yield %[[SCORE]]
+// CHECK: } -> tensor<?x?x?x?xf32>
+// CHECK: return %[[ATTN]] : tensor<?x?x?x?xf32>