[TTIR] Add SDPA fusing pattern #5999
Conversation
auto queryType = mlir::cast<RankedTensorType>(op.getQuery().getType());

// If sequence length of query is 1, we should use the decode op
if (queryType.getDimSize(2) == 1) {
  // Permute query: [B, H, 1, D] -> [1, B, H, D]
  Value permutedQuery = ttir_to_ttnn::utils::generatePermute(
      mlir::cast<TypedValue<mlir::RankedTensorType>>(adaptor.getQuery()),
      {2, 0, 1, 3}, rewriter, op.getLoc());

  // Broadcast mask head dimension if needed for decode op
  Value attentionMask = adaptor.getAttentionMask();
  if (attentionMask) {
    auto maskType = mlir::cast<RankedTensorType>(attentionMask.getType());
    int64_t numHeads = queryType.getDimSize(1);

    if (numHeads > 1) {
      SmallVector<int64_t> broadcastShape(maskType.getShape().begin(),
                                          maskType.getShape().end());
      broadcastShape[2] = numHeads;

      auto broadcastType = ttnn::utils::RankedTensorTypeFactory::create(
          maskType, broadcastShape);
      auto broadcastDims = ttmlir::utils::getBroadcastDimensions<int64_t>(
          maskType.getShape(), broadcastShape);
      auto shapeAttr =
          ttnn::ShapeAttr::get(rewriter.getContext(), broadcastDims);
      attentionMask = rewriter.create<ttnn::RepeatOp>(
          op.getLoc(), broadcastType, attentionMask, shapeAttr);
    }
  }

  auto decodeOp = rewriter.create<ttnn::ScaledDotProductAttentionDecodeOp>(
      op.getLoc(), permutedQuery.getType(), permutedQuery, adaptor.getKey(),
      adaptor.getValue(), op.getIsCausal(), attentionMask,
      /*cur_pos_tensor=*/Value(), /*attention_sink=*/Value(),
      adaptor.getScaleAttr(), /*memory_config=*/nullptr,
      /*program_config=*/nullptr);

  // Permute result back: [1, B, H, D] -> [B, H, 1, D]
  rewriter.replaceOp(
      op, ttir_to_ttnn::utils::generatePermute(
              decodeOp.getResult(), {1, 2, 0, 3}, rewriter, op.getLoc()));
} else {
  rewriter.replaceOpWithNewOp<ttnn::ScaledDotProductAttentionOp>(
      op, this->getTypeConverter()->convertType(op.getType()),
      adaptor.getQuery(), adaptor.getKey(), adaptor.getValue(),
      adaptor.getAttentionMask(), op.getIsCausal(), adaptor.getScaleAttr(),
      adaptor.getSlidingWindowSizeAttr(),
      /*memory_config=*/nullptr);
}
As mentioned in the description, this is probably not the best place to put the logic for deciding whether to create a ttnn.scaled_dot_product_attention or a ttnn.scaled_dot_product_attention_decode operation, so if anyone has a better idea, feel free to comment.
Codecov Report
❌ Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #5999      +/-   ##
==========================================
- Coverage   69.28%   69.21%   -0.08%
==========================================
  Files         333      335         +2
  Lines       50899    51193       +294
==========================================
+ Hits        35264    35431       +167
- Misses      15635    15762       +127

☔ View full report in Codecov by Sentry.
@jserbedzijaTT Great! Have you tested this on llama e2e (details)? If so, could you share the resulting IR?
      OptionalAttr<F32Attr>:$scale,
-     OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
+     OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config,
+     OptionalAttr<TTNN_SDPAProgramConfigAttr>:$program_config);
I believe we should pass this information to the optimizer (it's missing here).
In general, the optimizer code should be similar to the runtime code.
We also need to check whether this logic needs to change with this modification.
Ticket
Link to GitHub issue
Problem description
Scaled Dot Product Attention (SDPA) is a fundamental building block in transformer architectures, computing

Attention(Q, K, V) = softmax(Q · Kᵀ / √d_k) · V

Currently, we lack a fusing pattern in TTIR to combine this chain of operations into a single ttir.scaled_dot_product_attention op.
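For reference, here is a minimal, self-contained C++ sketch of that unfused chain (scaled Q·Kᵀ scores, optional additive mask, row-wise softmax, multiplication by V). It is illustrative only and not part of this PR; the names `sdpaReference` and `Matrix` are made up for the example.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// One attention head: Q is [Sq x D], K and V are [Sk x D], mask is [Sq x Sk].
using Matrix = std::vector<std::vector<float>>;

Matrix sdpaReference(const Matrix &Q, const Matrix &K, const Matrix &V,
                     const Matrix *mask = nullptr) {
  const std::size_t Sq = Q.size(), Sk = K.size(), D = Q[0].size();
  const float scale = 1.0f / std::sqrt(static_cast<float>(D));

  // scores = Q x K^T * scale (+ mask)
  Matrix scores(Sq, std::vector<float>(Sk, 0.0f));
  for (std::size_t i = 0; i < Sq; ++i) {
    for (std::size_t j = 0; j < Sk; ++j) {
      float dot = 0.0f;
      for (std::size_t d = 0; d < D; ++d)
        dot += Q[i][d] * K[j][d];
      scores[i][j] = dot * scale + (mask ? (*mask)[i][j] : 0.0f);
    }
  }

  // Numerically stable row-wise softmax over the key dimension.
  for (auto &row : scores) {
    const float maxVal = *std::max_element(row.begin(), row.end());
    float sum = 0.0f;
    for (float &v : row) {
      v = std::exp(v - maxVal);
      sum += v;
    }
    for (float &v : row)
      v /= sum;
  }

  // output = softmax(scores) x V
  Matrix out(Sq, std::vector<float>(V[0].size(), 0.0f));
  for (std::size_t i = 0; i < Sq; ++i)
    for (std::size_t j = 0; j < Sk; ++j)
      for (std::size_t d = 0; d < V[0].size(); ++d)
        out[i][d] += scores[i][j] * V[j][d];
  return out;
}
```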
What's changed

SDPA Fusing Pattern:
- Added a new pattern to the TTIRFusing pass that identifies the canonical attention computation pattern (a rough sketch of such a pattern follows after this list)
- Fuses the matched chain of operations into a single ttir.scaled_dot_product_attention operation
- Broadcasts the attention mask to [batch_size, 1, query_seq_len, key_seq_len] if needed
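As a rough illustration of how such a fusing pattern can be expressed with MLIR's rewrite infrastructure, here is a minimal sketch. Everything in the ttir:: namespace below (op class names, accessors, builder arguments) is an assumption made for the example, not necessarily the code in this PR; only the OpRewritePattern scaffolding is standard MLIR.

```cpp
#include "mlir/IR/PatternMatch.h"

// Hedged sketch only: ttir::MatmulOp, ttir::SoftmaxOp, ttir::AddOp,
// ttir::ScaledDotProductAttentionOp and their accessors are assumptions.
// The actual pattern in the TTIRFusing pass may also match and validate the
// scale multiply, the K transpose, the mask shape, and the causal flag.
struct FuseSDPAPattern : public mlir::OpRewritePattern<ttir::MatmulOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ttir::MatmulOp attnTimesV,
                  mlir::PatternRewriter &rewriter) const override {
    // Anchor on the final matmul: softmax(scores) x V.
    auto softmax = attnTimesV.getA().getDefiningOp<ttir::SoftmaxOp>();
    if (!softmax)
      return mlir::failure();

    // Walk back through an optional additive mask to the Q x K^T scores.
    mlir::Value scores = softmax.getInput();
    mlir::Value mask; // stays null when no mask is present
    if (auto add = scores.getDefiningOp<ttir::AddOp>()) {
      mask = add.getB();
      scores = add.getA();
    }
    auto qk = scores.getDefiningOp<ttir::MatmulOp>();
    if (!qk)
      return mlir::failure();

    // Replace the matched chain with a single fused TTIR op.
    rewriter.replaceOpWithNewOp<ttir::ScaledDotProductAttentionOp>(
        attnTimesV, attnTimesV.getType(), /*query=*/qk.getA(),
        /*key=*/qk.getB(), /*value=*/attnTimesV.getB(), mask,
        /*is_causal=*/false, /*scale=*/mlir::FloatAttr(),
        /*sliding_window_size=*/mlir::IntegerAttr());
    return mlir::success();
  }
};
```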
Lowering of ttir.scaled_dot_product_attention
Changed the lowering of ttir.scaled_dot_product_attention in the TTIRToTTNN pass. It now lowers to either ttnn.scaled_dot_product_attention or ttnn.scaled_dot_product_attention_decode, depending on the sequence_length dimension of the query tensor (if sequence_length is 1, it uses the decode operation); the check is restated below.
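Restated from the conversion snippet earlier in this conversation (the variable name useDecodeOp is ours):

```cpp
// Query shape is [B, H, S, D]; a query sequence length of 1 selects the
// ttnn.scaled_dot_product_attention_decode path, otherwise the regular op.
auto queryType = mlir::cast<mlir::RankedTensorType>(op.getQuery().getType());
bool useDecodeOp = (queryType.getDimSize(2) == 1);
```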
Workarounds:
The ttnn.scaled_dot_product_attention_decode op requires a program_config to succeed, so we create a default one in the workaround.

Testing:
Checklist