Skip to content

Commit 192602f

Browse files
authored
[DispatchCreation] Limit attention broadcast fusion (#24127)
Reject fusions that would broadcast away an entire attention dim group. Closes #24051 Signed-off-by: Ian Wood <ianwood@u.northwestern.edu>
1 parent bae8f56 commit 192602f

2 files changed

Lines changed: 101 additions & 1 deletion

File tree

compiler/src/iree/compiler/DispatchCreation/ElementwiseOpFusion.cpp

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
15+
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
1516
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
17+
#include "iree/compiler/Dialect/LinalgExt/Utils/IndexingUtils.h"
1618
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
1719
#include "iree/compiler/DispatchCreation/FusionUtils.h"
1820
#include "iree/compiler/DispatchCreation/Passes.h"
@@ -47,6 +49,56 @@ static SmallVector<T> applyProjectedPermutation(const SmallVectorImpl<T> &input,
4749
return result;
4850
}
4951

52+
// Decides whether fusing the producer of `fusedOperand` into its consumer
// would "broadcast away" a whole dimension group of an attention op.
//
// Returns false when the owner of `fusedOperand` is not an
// IREE::LinalgExt::AttentionOp, or when the operand's defining op is absent,
// is not a linalg::LinalgOp, or is not accepted by
// IREE::LinalgExt::isBroadcastingOp.
//
// Returns true when AttentionOpDetail::get fails on the attention op's
// query/key/value/output maps, or when, for at least one of the batch, M, K1,
// K2 or N dim groups, the consumer's current indexing map for this operand
// references a dim of the group but the composed producer-input map
// references none of them.
static bool wouldBroadcastAwayAttentionDimGroup(OpOperand *fusedOperand) {
53+
// Only operands consumed by an attention op are ever flagged.
auto attentionOp =
54+
dyn_cast<IREE::LinalgExt::AttentionOp>(fusedOperand->getOwner());
55+
if (!attentionOp) {
56+
return false;
57+
}
58+
59+
// getDefiningOp() may yield null (hence dyn_cast_if_present); a missing or
// non-broadcasting producer is never flagged.
auto producer = dyn_cast_if_present<linalg::LinalgOp>(
60+
fusedOperand->get().getDefiningOp());
61+
if (!producer || !IREE::LinalgExt::isBroadcastingOp(producer)) {
62+
return false;
63+
}
64+
65+
// Classify the attention iteration dims into groups. If that classification
// cannot be computed, report true rather than guessing.
FailureOr<IREE::LinalgExt::AttentionOpDetail> maybeOpInfo =
66+
IREE::LinalgExt::AttentionOpDetail::get(
67+
attentionOp.getQueryMap(), attentionOp.getKeyMap(),
68+
attentionOp.getValueMap(), attentionOp.getOutputMap());
69+
if (failed(maybeOpInfo)) {
70+
return true;
71+
}
72+
IREE::LinalgExt::AttentionOpDetail opInfo = std::move(*maybeOpInfo);
73+
74+
// Only the producer's first input and first init operand are inspected.
AffineMap producerInputMap =
75+
producer.getMatchingIndexingMap(producer.getDpsInputOperand(0));
76+
AffineMap producerResultMap =
77+
producer.getMatchingIndexingMap(producer.getDpsInitOperand(0));
78+
AffineMap consumerInputMap = attentionOp.getMatchingIndexingMap(fusedOperand);
79+
// Producer input map composed with the inverse of the producer result map and
// then with the consumer's map for this operand; intended as the map the
// producer's input would have when read directly by the attention op.
// NOTE(review): assumes producerResultMap is invertible for every op that
// isBroadcastingOp accepts -- confirm against that helper.
AffineMap fusedInputMap =
80+
producerInputMap.compose(inversePermutation(producerResultMap))
81+
.compose(consumerInputMap);
82+
83+
// A group counts as broadcast away when it is non-empty, the consumer's
// current map uses one of its dims, and the fused map uses none of them.
auto broadcastsAwayDimGroup = [&](ArrayRef<int64_t> dims) {
84+
// True if any result of `map` is a plain AffineDimExpr whose position is in
// `dims`; results that are not AffineDimExpr are ignored.
auto usesAnyDim = [](AffineMap map, ArrayRef<int64_t> dims) {
85+
return llvm::any_of(map.getResults(), [&](AffineExpr expr) {
86+
auto dimExpr = dyn_cast<AffineDimExpr>(expr);
87+
return dimExpr && llvm::any_of(dims, [&](int64_t dim) {
88+
return dimExpr.getPosition() == dim;
89+
});
90+
});
91+
};
92+
return !dims.empty() && usesAnyDim(consumerInputMap, dims) &&
93+
!usesAnyDim(fusedInputMap, dims);
94+
};
95+
96+
// Dim groups checked: batch, M, K1, K2 and N, as reported by opInfo.
SmallVector<ArrayRef<int64_t>> importantDimGroups = {
97+
opInfo.getBatchDims(), opInfo.getMDims(), opInfo.getK1Dims(),
98+
opInfo.getK2Dims(), opInfo.getNDims()};
99+
return llvm::any_of(importantDimGroups, broadcastsAwayDimGroup);
100+
}
101+
50102
//===----------------------------------------------------------------------===//
51103
// GatherFusionPattern
52104
//===----------------------------------------------------------------------===//
@@ -192,7 +244,13 @@ void ElementwiseOpFusionPass::runOnOperation() {
192244
Operation *producer = fusedOperand->get().getDefiningOp();
193245
Operation *consumer = fusedOperand->getOwner();
194246

195-
return IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer});
247+
if (!IREE::Flow::isNonNullAndOutsideDispatch({producer, consumer})) {
248+
return false;
249+
}
250+
if (wouldBroadcastAwayAttentionDimGroup(fusedOperand)) {
251+
return false;
252+
}
253+
return true;
196254
};
197255
RewritePatternSet linalgExtFusionPatterns(context);
198256
IREE::LinalgExt::populateFuseLinalgExtOpsWithTransposes(

compiler/src/iree/compiler/DispatchCreation/test/elementwise_op_fusion.mlir

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,48 @@ util.func public @fuse_attention_with_broadcast(%arg0: tensor<4x8x128x?xf16>, %a
284284
// CHECK-SAME: ins(%[[ARG1]], %[[ARG2]], %[[ARG0]], %[[ARG3]], %[[ARG4]] :
285285
// CHECK: util.return %[[ATTENTION]]
286286

287+
// -----
288+
289+
// Negative fusion test. %v (4x32x64) is broadcast along a new trailing
// dimension of size 128 by a linalg.generic, and the broadcast result is the
// third input of the attention op, accessed through the map (d0, d1, d4, d5).
// Among the attention maps below, d5 appears only in that input's map and in
// the output map, and d5 is exactly the dimension the broadcast introduces.
// The directives that follow this function expect the linalg.generic to
// remain and the attention op to consume its result.
util.func public @dont_fuse_attention_with_broadcasted_away_n_dim(
290+
%q: tensor<4x32x64x16xf16>,
291+
%k: tensor<4x32x64x16xf16>,
292+
%v: tensor<4x32x64xf16>,
293+
%scale: f16) -> tensor<4x32x64x128xf16> {
294+
%empty_v = tensor.empty() : tensor<4x32x64x128xf16>
295+
// Pure broadcast: the body only yields %in, and the input map drops d3.
%v_broadcast = linalg.generic {
296+
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
297+
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
298+
iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
299+
ins(%v : tensor<4x32x64xf16>) outs(%empty_v : tensor<4x32x64x128xf16>) {
300+
^bb0(%in: f16, %out: f16):
301+
linalg.yield %in : f16
302+
} -> tensor<4x32x64x128xf16>
303+
%empty_out = tensor.empty() : tensor<4x32x64x128xf16>
304+
// Indexing maps are listed in the order of the ins operands (%q, %k,
// %v_broadcast, %scale), followed by the map for the outs operand.
%attention = iree_linalg_ext.attention {
305+
indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>,
306+
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d3)>,
307+
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>,
308+
affine_map<(d0, d1, d2, d3, d4, d5) -> ()>,
309+
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d5)>]}
310+
ins(%q, %k, %v_broadcast, %scale :
311+
tensor<4x32x64x16xf16>, tensor<4x32x64x16xf16>,
312+
tensor<4x32x64x128xf16>, f16)
313+
outs(%empty_out : tensor<4x32x64x128xf16>) {
314+
^bb0(%score: f16):
315+
iree_linalg_ext.yield %score : f16
316+
} -> tensor<4x32x64x128xf16>
317+
util.return %attention : tensor<4x32x64x128xf16>
318+
}
319+
// CHECK-LABEL: func public @dont_fuse_attention_with_broadcasted_away_n_dim
320+
// CHECK-SAME: %[[Q:[a-zA-Z0-9]+]]:
321+
// CHECK-SAME: %[[K:[a-zA-Z0-9]+]]:
322+
// CHECK-SAME: %[[V:[a-zA-Z0-9]+]]:
323+
// CHECK-SAME: %[[SCALE:[a-zA-Z0-9]+]]:
324+
// CHECK: %[[V_BCAST:.+]] = linalg.generic
325+
// CHECK: %[[ATTENTION:.+]] = iree_linalg_ext.attention
326+
// CHECK-SAME: ins(%[[Q]], %[[K]], %[[V_BCAST]], %[[SCALE]] :
327+
// CHECK: util.return %[[ATTENTION]]
328+
287329

288330
// -----
289331

0 commit comments

Comments
 (0)