
Conversation

@jserbedzijaTT
Contributor

Ticket

Link to Github Issue

Problem description

Scaled Dot Product Attention (SDPA) is a fundamental building block in transformer architectures, computing

attention(Q, K, V) = softmax(Q @ K^T * scale + mask) @ V

Currently, we lack a fusing pattern in TTIR to combine this chain of operations into a single ttir.scaled_dot_product_attention op.
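
For reference, the computation being fused can be written as a minimal scalar sketch (single batch and head, row-wise softmax). This is purely illustrative and not the TTNN kernel; the names and layout are assumptions, and scale is typically 1/sqrt(head_dim).

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Naive single-head reference: Q is [S_q x D], K and V are [S_k x D],
// mask is [S_q x S_k]. Computes softmax(Q @ K^T * scale + mask) @ V.
using Matrix = std::vector<std::vector<float>>;

Matrix referenceSDPA(const Matrix &Q, const Matrix &K, const Matrix &V,
                     const Matrix &mask, float scale) {
  const std::size_t seqQ = Q.size(), seqK = K.size(), dim = Q[0].size();
  Matrix out(seqQ, std::vector<float>(dim, 0.0f));
  for (std::size_t i = 0; i < seqQ; ++i) {
    // Scores for query row i: Q[i] . K[j] * scale + mask[i][j].
    std::vector<float> scores(seqK, 0.0f);
    float maxScore = -INFINITY;
    for (std::size_t j = 0; j < seqK; ++j) {
      float dot = 0.0f;
      for (std::size_t d = 0; d < dim; ++d)
        dot += Q[i][d] * K[j][d];
      scores[j] = dot * scale + mask[i][j];
      maxScore = std::max(maxScore, scores[j]);
    }
    // Numerically stable row-wise softmax.
    float denom = 0.0f;
    for (std::size_t j = 0; j < seqK; ++j) {
      scores[j] = std::exp(scores[j] - maxScore);
      denom += scores[j];
    }
    // Output row i is the softmax-weighted sum of V rows.
    for (std::size_t j = 0; j < seqK; ++j)
      for (std::size_t d = 0; d < dim; ++d)
        out[i][d] += (scores[j] / denom) * V[j][d];
  }
  return out;
}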

What's changed

SDPA Fusing Pattern:

  • Added a pattern in the TTIRFusing pass that identifies the canonical attention computation pattern:
softmax(Q @ K^T * scale + mask) @ V
  • Fuses the matched sequence of operations into a single ttir.scaled_dot_product_attention operation (a simplified sketch of the match logic follows this list)
  • Handles optional attention mask and scaling factors
  • Broadcasts the mask to shape [batch_size, 1, query_seq_len, key_seq_len] if needed
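
Roughly, the matcher walks the def-use chain upward from the final matmul. The sketch below illustrates the idea only: the TTIR op class names, operand accessors (getA/getB/getLhs/getRhs/getInput), header paths, and the fused-op builder arguments are assumptions, and the real pattern in TTIRFusing.cpp additionally matches the transpose on K, validates shapes, and broadcasts the mask as described above.

#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
// Header path, namespace, op names and accessors are assumptions for this sketch.
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h"

namespace ttir = mlir::tt::ttir; // assumed namespace

namespace {
// Matches softmax(Q @ K^T * scale + mask) @ V rooted at the final matmul and
// rewrites it to a single fused op. Scale and mask are optional.
struct SDPAFusingSketch : public mlir::OpRewritePattern<ttir::MatmulOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ttir::MatmulOp finalMatmul,
                  mlir::PatternRewriter &rewriter) const override {
    // A softmax must produce the left operand of the final matmul.
    auto softmax = finalMatmul.getA().getDefiningOp<ttir::SoftmaxOp>();
    if (!softmax)
      return mlir::failure();

    // Peel the optional "+ mask" and "* scale" off the softmax input.
    mlir::Value scores = softmax.getInput();
    mlir::Value mask, scale;
    if (auto add = scores.getDefiningOp<ttir::AddOp>()) {
      mask = add.getRhs();
      scores = add.getLhs();
    }
    if (auto mul = scores.getDefiningOp<ttir::MultiplyOp>()) {
      scale = mul.getRhs();
      scores = mul.getLhs();
    }

    // What remains must be Q @ K^T. The real pattern also matches the
    // transpose on K and feeds the un-transposed key into the fused op;
    // that step is omitted here for brevity.
    auto qkMatmul = scores.getDefiningOp<ttir::MatmulOp>();
    if (!qkMatmul)
      return mlir::failure();

    // Builder arguments are simplified; the real op takes more operands.
    rewriter.replaceOpWithNewOp<ttir::ScaledDotProductAttentionOp>(
        finalMatmul, finalMatmul.getType(), /*query=*/qkMatmul.getA(),
        /*key=*/qkMatmul.getB(), /*value=*/finalMatmul.getB(), mask, scale);
    return mlir::success();
  }
};
} // namespace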

Lowering of ttir.scaled_dot_product_attention

  • Updated the lowering pattern of ttir.scaled_dot_product_attention in the TTIRToTTNN pass. It now lowers to either ttnn.scaled_dot_product_attention or ttnn.scaled_dot_product_attention_decode, depending on the sequence_length dimension of the query tensor: if sequence_length is 1, the decode variant is used (the dispatch predicate is sketched below).
    • Note: This logic might need improvement. If anyone has a better idea, please feel free to suggest it.
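
For reference, the dispatch boils down to a single shape check on the query. This is a sketch under the assumption of a rank-4 [batch, heads, seq, head_dim] query layout, matching the diff shown further down; the helper name is hypothetical.

#include "mlir/IR/BuiltinTypes.h"

// Decode mode is chosen when the query's sequence dimension (dim 2 of
// [B, H, S, D]) is 1; otherwise the regular SDPA op is used.
static bool useDecodeVariant(mlir::RankedTensorType queryType) {
  return queryType.getRank() == 4 && queryType.getDimSize(2) == 1;
}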

Workarounds:

  • Padding workaround: Pads the sequence dimensions to multiples of the tile size, then slices the output back to the original shape (see the sketch after this list)
  • SDPA decode config workaround: In some cases the ttnn.scaled_dot_product_attention_decode op currently requires a program_config to succeed, so the workaround creates a default one
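
A minimal sketch of the arithmetic behind the padding workaround. The 32x32 TTNN tile size is an assumption and the helper name is hypothetical; the actual rewriter pads the sequence dimensions of the inputs up to the next tile multiple before the op and slices the result back afterwards.

#include <cstdint>

// Tile size assumed to be 32, the usual TTNN tile dimension.
constexpr int64_t kTileSize = 32;

// Hypothetical helper: round a sequence length up to the next tile multiple.
constexpr int64_t padToTileMultiple(int64_t seqLen) {
  return ((seqLen + kTileSize - 1) / kTileSize) * kTileSize;
}

// Example: a query_seq_len of 37 is padded to 64 before the SDPA op runs,
// and the output is then sliced back to 37 rows.
static_assert(padToTileMultiple(37) == 64, "37 rounds up to 64");
static_assert(padToTileMultiple(64) == 64, "already tile-aligned");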

Testing:

  • Added Python and MLIR tests for the fusing pattern and the workarounds

Checklist

  • New/Existing tests provide coverage for changes

Comment on lines +2073 to +2122
auto queryType = mlir::cast<RankedTensorType>(op.getQuery().getType());

// If sequence length of query is 1, we should use the decode op
if (queryType.getDimSize(2) == 1) {
  // Permute query: [B, H, 1, D] -> [1, B, H, D]
  Value permutedQuery = ttir_to_ttnn::utils::generatePermute(
      mlir::cast<TypedValue<mlir::RankedTensorType>>(adaptor.getQuery()),
      {2, 0, 1, 3}, rewriter, op.getLoc());

  // Broadcast mask head dimension if needed for decode op
  Value attentionMask = adaptor.getAttentionMask();
  if (attentionMask) {
    auto maskType = mlir::cast<RankedTensorType>(attentionMask.getType());
    int64_t numHeads = queryType.getDimSize(1);

    if (numHeads > 1) {
      SmallVector<int64_t> broadcastShape(maskType.getShape().begin(),
                                          maskType.getShape().end());
      broadcastShape[2] = numHeads;

      auto broadcastType = ttnn::utils::RankedTensorTypeFactory::create(
          maskType, broadcastShape);
      auto broadcastDims = ttmlir::utils::getBroadcastDimensions<int64_t>(
          maskType.getShape(), broadcastShape);
      auto shapeAttr =
          ttnn::ShapeAttr::get(rewriter.getContext(), broadcastDims);
      attentionMask = rewriter.create<ttnn::RepeatOp>(
          op.getLoc(), broadcastType, attentionMask, shapeAttr);
    }
  }

  auto decodeOp = rewriter.create<ttnn::ScaledDotProductAttentionDecodeOp>(
      op.getLoc(), permutedQuery.getType(), permutedQuery, adaptor.getKey(),
      adaptor.getValue(), op.getIsCausal(), attentionMask,
      /*cur_pos_tensor=*/Value(), /*attention_sink=*/Value(),
      adaptor.getScaleAttr(), /*memory_config=*/nullptr,
      /*program_config=*/nullptr);

  // Permute result back: [1, B, H, D] -> [B, H, 1, D]
  rewriter.replaceOp(
      op, ttir_to_ttnn::utils::generatePermute(
              decodeOp.getResult(), {1, 2, 0, 3}, rewriter, op.getLoc()));
} else {
  rewriter.replaceOpWithNewOp<ttnn::ScaledDotProductAttentionOp>(
      op, this->getTypeConverter()->convertType(op.getType()),
      adaptor.getQuery(), adaptor.getKey(), adaptor.getValue(),
      adaptor.getAttentionMask(), op.getIsCausal(), adaptor.getScaleAttr(),
      adaptor.getSlidingWindowSizeAttr(),
      /*memory_config=*/nullptr);
}
Contributor Author

As mentioned in the description, this is probably not the best place to put the logic for deciding whether to create a ttnn.scaled_dot_product_attention or a ttnn.scaled_dot_product_attention_decode operation, so if anyone has a better idea, feel free to comment.

@codecov-commenter

codecov-commenter commented Nov 23, 2025

Codecov Report

❌ Patch coverage is 57.47801% with 145 lines in your changes missing coverage. Please review.
✅ Project coverage is 69.21%. Comparing base (0bbd65e) to head (73aaa10).
✅ All tests successful. No failed tests found.

Files with missing lines Patch % Lines
lib/Dialect/TTIR/Transforms/TTIRFusing.cpp 30.49% 98 Missing ⚠️
lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp 24.32% 28 Missing ⚠️
lib/Conversion/TTIRToTTNN/Utils.cpp 0.00% 9 Missing ⚠️
lib/Dialect/TTNN/IR/TTNNOps.cpp 80.76% 5 Missing ⚠️
...tProductAttentionPadSequenceDimRewriterPattern.cpp 96.70% 3 Missing ⚠️
include/ttmlir/Target/Utils/MLIRToFlatbuffer.h 90.00% 1 Missing ⚠️
lib/Target/TTNN/TTNNToFlatbuffer.cpp 83.33% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #5999      +/-   ##
==========================================
- Coverage   69.28%   69.21%   -0.08%     
==========================================
  Files         333      335       +2     
  Lines       50899    51193     +294     
==========================================
+ Hits        35264    35431     +167     
- Misses      15635    15762     +127     


@odjuricicTT
Contributor

@jserbedzijaTT Great!

Have you tested this on llama e2e (details)? If so, could you share the resulting IR?

     OptionalAttr<F32Attr>:$scale,
-    OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config);
+    OptionalAttr<TTNN_MemoryConfigAttr>:$memory_config,
+    OptionalAttr<TTNN_SDPAProgramConfigAttr>:$program_config);
Contributor

I believe we should pass this information to the optimizer (it's missing here).
In general, the optimizer code should be similar to the runtime code.

We also need to check whether this logic needs to change with this modification.

@jserbedzijaTT force-pushed the jserbedzija/add_sdpa_ttir_fusing_pattern branch 2 times, most recently from 798847d to 04ab34f on November 25, 2025 at 15:38
@jserbedzijaTT force-pushed the jserbedzija/add_sdpa_ttir_fusing_pattern branch from 04ab34f to 73aaa10 on November 25, 2025 at 15:38