@@ -149,33 +149,25 @@ struct AttentionOpConversion
         tensor::EmptyOp::create(rewriter, loc, outputType.getShape(),
                                 outputType.getElementType(), dynSizes);
 
-    // TODO: This is a hack. This should be replaced with a simple getScale()
-    // when support for scaling is plumbed to TMTensor on the torch-mlir side.
-    // Until then, we are using the default value used in scaled dot product
-    // attention by PyTorch (most models use the default value because it makes
-    // the variance of the result of softmax 1 when the mean of Q, K is 0).
-    // We use scale = 1 / sqrt(d), where d is the head dimension.
-    // See https://paperswithcode.com/method/scaled for more details.
-    //
-    // TODO: We are currently assuming that head dimension is dim = -1. Once we
-    // have support for batch dims using more general indexing maps, we should
-    // change this and rely on more general mechanisms.
-    // TODO: We are currently not handling dynamic shape of head dimensions at
-    // all. This is because it messes with dispatch formation. This should be
-    // fixed.
-    ArrayRef<int64_t> queryShape = op.getQueryType().getShape();
-    int64_t headDim = queryShape.back();
-    if (headDim == ShapedType::kDynamic) {
-      return op->emitOpError("NYI: Dynamic head dimension");
-    }
-
-    // Attention only works for FloatType.
+    // Compute scale = 1 / sqrt(headDim), where headDim is the last dimension
+    // of the query tensor. When headDim is static, fold to a constant.
     FloatType targetType = cast<FloatType>(op.getQueryType().getElementType());
-
-    double dk = static_cast<double>(headDim);
-    dk = 1.0 / std::sqrt(dk);
-    Value scale = arith::ConstantOp::create(
-        rewriter, loc, targetType, rewriter.getFloatAttr(targetType, dk));
+    int64_t headDim = op.getQueryType().getShape().back();
+    Value scale;
+    if (headDim != ShapedType::kDynamic) {
+      double dk = 1.0 / std::sqrt(static_cast<double>(headDim));
+      scale = arith::ConstantOp::create(rewriter, loc, targetType,
+                                        rewriter.getFloatAttr(targetType, dk));
+    } else {
+      int64_t queryRank = op.getQueryType().getRank();
+      Value headDimIndex =
+          tensor::DimOp::create(rewriter, loc, query, queryRank - 1);
+      Value headDimInt = arith::IndexCastOp::create(
+          rewriter, loc, rewriter.getI64Type(), headDimIndex);
+      Value headDimFloat =
+          arith::SIToFPOp::create(rewriter, loc, targetType, headDimInt);
+      scale = math::RsqrtOp::create(rewriter, loc, headDimFloat);
+    }
 
     // Add batches to standard attention indexing maps.
     SmallVector<AffineMap> indexingMaps =
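For reference, the factor being materialized here is the standard scaled dot-product attention scale. Under the assumption the removed comment cites (entries of Q and K independent with zero mean and unit variance), each attention score sums d unit-variance products, so dividing by sqrt(d) keeps the softmax input at unit variance:

\[
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d}}\right) V,
\qquad
\mathrm{Var}\!\left[\sum_{i=1}^{d} q_i k_i\right] = \sum_{i=1}^{d} \mathrm{Var}[q_i]\,\mathrm{Var}[k_i] = d.
\]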
compiler/plugins/input/Torch/InputConversion/test/attention.mlir (62 additions, 0 deletions)
@@ -94,3 +94,65 @@ func.func @attention_dyn(%arg0: tensor<?x?x4xf32>, %arg1: tensor<?x?x4xf32>, %ar
 // CHECK: linalg_ext.yield %[[SCORE]]
 // CHECK: } -> tensor<?x?x4xf32>
 // CHECK: return %[[ATTN]] : tensor<?x?x4xf32>
+
+// -----
+func.func @attention_dynamic_head_dim(%arg0: tensor<5x2x3x?xf32>, %arg1: tensor<5x2x3x?xf32>, %arg2: tensor<5x2x3x?xf32>, %arg3: tensor<5x2x3x?xf32>) -> (tensor<5x2x3x?xf32>) {
+  %0 = tm_tensor.attention ins(%arg0, %arg1, %arg2 : tensor<5x2x3x?xf32>, tensor<5x2x3x?xf32>, tensor<5x2x3x?xf32>) outs(%arg3: tensor<5x2x3x?xf32>) -> tensor<5x2x3x?xf32>
+  return %0 : tensor<5x2x3x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: func.func @attention_dynamic_head_dim(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<5x2x3x?xf32>, %[[ARG1:.*]]: tensor<5x2x3x?xf32>, %[[ARG2:.*]]: tensor<5x2x3x?xf32>,
+// CHECK-SAME: %[[ARG3:.*]]: tensor<5x2x3x?xf32>) -> tensor<5x2x3x?xf32> {
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK: %[[DIM_V:.*]] = tensor.dim %[[ARG2]], %[[C3]]
+// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM_V]]) : tensor<5x2x3x?xf32>
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG0]], %[[C3]]
+// CHECK: %[[DIM_INT:.*]] = arith.index_cast %[[DIM]] : index to i64
+// CHECK: %[[DIM_FLOAT:.*]] = arith.sitofp %[[DIM_INT]] : i64 to f32
+// CHECK: %[[SCALE:.*]] = math.rsqrt %[[DIM_FLOAT]] : f32
+// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<5x2x3x?xf32>, tensor<5x2x3x?xf32>, tensor<5x2x3x?xf32>, f32) outs(%[[EMPTY]] : tensor<5x2x3x?xf32>) {
+// CHECK: ^[[BLOCK:.+]](%[[SCORE:.+]]: f32):
+// CHECK: linalg_ext.yield %[[SCORE]]
+// CHECK: } -> tensor<5x2x3x?xf32>
+// CHECK: return %[[ATTN]] : tensor<5x2x3x?xf32>
+
+// -----
+func.func @attention_fully_dynamic(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>, %arg2: tensor<?x?x?x?xf32>, %arg3: tensor<?x?x?x?xf32>) -> (tensor<?x?x?x?xf32>) {
+  %0 = tm_tensor.attention ins(%arg0, %arg1, %arg2 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) outs(%arg3: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
+  return %0 : tensor<?x?x?x?xf32>
+}
+
+// CHECK-DAG: #[[$MAP_Q:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
+// CHECK-DAG: #[[$MAP_K:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
+// CHECK-DAG: #[[$MAP_V:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
+// CHECK-DAG: #[[$MAP_S:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
+// CHECK-DAG: #[[$MAP_O:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
+
+// CHECK-LABEL: func.func @attention_fully_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?x?xf32>, %[[ARG2:.*]]: tensor<?x?x?x?xf32>,
+// CHECK-SAME: %[[ARG3:.*]]: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[DIM2:.*]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK: %[[DIM_V:.*]] = tensor.dim %[[ARG2]], %[[C3]]
+// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM_V]]) : tensor<?x?x?x?xf32>
+// CHECK: %[[DIM_Q:.*]] = tensor.dim %[[ARG0]], %[[C3]]
+// CHECK: %[[DIM_INT:.*]] = arith.index_cast %[[DIM_Q]] : index to i64
+// CHECK: %[[DIM_FLOAT:.*]] = arith.sitofp %[[DIM_INT]] : i64 to f32
+// CHECK: %[[SCALE:.*]] = math.rsqrt %[[DIM_FLOAT]] : f32
+// CHECK: %[[ATTN:.*]] = iree_linalg_ext.attention {indexing_maps = [#[[$MAP_Q]], #[[$MAP_K]], #[[$MAP_V]], #[[$MAP_S]], #[[$MAP_O]]]} ins(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[SCALE]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, f32) outs(%[[EMPTY]] : tensor<?x?x?x?xf32>) {
+// CHECK: ^[[BLOCK:.+]](%[[SCORE:.+]]: f32):
+// CHECK: linalg_ext.yield %[[SCORE]]
+// CHECK: } -> tensor<?x?x?x?xf32>
+// CHECK: return %[[ATTN]] : tensor<?x?x?x?xf32>
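The new tests pin down the semantics for dynamic head dimensions: the scale remains PyTorch's scaled_dot_product_attention default of 1 / sqrt(head_dim), only computed at runtime via math.rsqrt rather than folded to a constant. As a minimal cross-check of that default in PyTorch itself (the concrete shapes are illustrative, with the dynamic head dimension from the tests above fixed to 8; this snippet is not part of the PR):

import math

import torch
import torch.nn.functional as F

q = torch.randn(5, 2, 3, 8)  # batch 5, heads 2, sequence 3, head dim 8
k = torch.randn(5, 2, 3, 8)
v = torch.randn(5, 2, 3, 8)

# With scale=None, PyTorch defaults to 1 / sqrt(head_dim), the same value the
# conversion materializes (a constant when static, math.rsqrt when dynamic).
out = F.scaled_dot_product_attention(q, k, v)

# Equivalent explicit computation with the default scale spelled out.
scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
ref = torch.softmax(scores, dim=-1) @ v
torch.testing.assert_close(out, ref)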