Add fused minimal matmul addcmul operation #36502
Conversation
Pull request overview
This PR introduces a fused dit_minimal_matmul_addcmul_fused operation that combines minimal_matmul and addcmul for improved performance in DiT transformer blocks (specifically targeting Wan2.2). The PR also refactors the minimal_matmul device operations by unifying the previously separate minimal_matmul and minimal_matmul_split implementations into a single device operation that supports both single and chunked outputs.
Changes:
- Added new fused operation `dit_minimal_matmul_addcmul_fused` that computes `output = residual + scalar * matmul(input, weight) * gate`
- Unified `minimal_matmul` and `minimal_matmul_split` device operations, changing return types from `Tensor` to `std::vector<Tensor>`
- Extended kernels to support fused ternary (addcmul) operations with new circular buffers and runtime parameters
Reviewed changes
Copilot reviewed 28 out of 28 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| dit_minimal_matmul_addcmul_fused/* | New operation implementation with nanobind bindings and comprehensive documentation |
| minimal_matmul_device_operation.* | Unified device operation supporting both single and split outputs, added ternary fusion parameters |
| minimal_matmul_program_factory.* | Added circular buffers for ternary inputs, extended runtime argument handling |
| minimal_matmul.cpp, minimal_matmul_split.cpp | Updated to use unified device operation returning vector of tensors |
| kernels/compute.cpp | Added add_bias_and_addcmul_block function implementing fused bias and addcmul logic |
| kernels/dm_in*.cpp | Extended dataflow kernels to read and process ternary input tensors |
| kernels/matmul_dataflow_common.hpp | Added read_ternary_blocks_sync helper for reading ternary tensors |
| minimal_matmul_split_* (deleted) | Removed duplicate device operation files now unified with minimal_matmul |
| CMakeLists.txt | Updated build configuration to remove split-specific files and add new fused operation |
| test_dit_minimal_matmul_addcmul_fused.py | Comprehensive tests covering basic functionality, Wan2.2 shapes, and different scalar values |
/codeowners ping

Hi Borys Bradel (@bbradelTT), Colman Glagovich (@cglagovichTT), Edwin Lee (@edwinleeTT), Izajasz Wrosz (@iwroszTT), Jonathan Su (@jonathansuTT), NSexton (@nsextonTT), this PR Add fused minimal matmul addcmul operation by Nathan Maurice (@nmauriceTT) needs your approval/review to merge this.
Ticket
#35915
Problem description
For Wan2.2, we want to fuse minimal_matmul and addcmul, like the following pattern:
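The unfused pattern being targeted can be sketched as follows (a NumPy reference for the semantics only; the real op works on ttnn tensors, and the function name and shapes here are illustrative):

```python
import numpy as np

def unfused_pattern(x, weight, residual, gate, scalar=1.0):
    """Reference semantics of the two-op sequence being fused."""
    mm = x @ weight                       # minimal_matmul
    return residual + scalar * mm * gate  # addcmul(residual, mm, gate, value=scalar)

x = np.ones((2, 3), dtype=np.float32)
w = np.ones((3, 4), dtype=np.float32)
residual = np.zeros((2, 4), dtype=np.float32)
gate = np.ones((1, 4), dtype=np.float32)  # broadcast over rows
out = unfused_pattern(x, w, residual, gate, scalar=2.0)
```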
What's changed
This adds a new `dit_minimal_matmul_addcmul_fused` operation that performs minimal_matmul + addcmul in a single op. This operation is equivalent to: `output = residual + scalar * matmul(input, weight) * gate`.
To make `minimal_matmul` more extensible, I've also merged its device operation with that of `minimal_matmul_split` (i.e. reduced code duplication). The kernels of `dit_minimal_matmul_addcmul_fused` have been implemented by modifying the `minimal_matmul` kernels. The operation itself is also defined by calling `minimal_matmul` (whose device operation has been updated with new parameters).

Note: Ideally, we'd like to use the `addcmul_tile` LLK. But it seems that `unary_bcast<BroadcastType::ROW>` does not work with fp32_acc_to_dst. Instead, row-broadcast of `ternary_a` is done through `add_bcast_tile` (FPU). The downside is that the output should be less accurate than with addcmul. If accuracy turns out to be a problem, we can switch to another workaround (e.g. doing the broadcasting in the dataflow kernels).
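The ROW broadcast mentioned above can be illustrated numerically: a `(1, N)` `ternary_a` tensor is replicated across the `M` rows of the matmul output. This NumPy sketch (shapes illustrative, explicit replication used as a stand-in for the on-device broadcast) checks that the broadcast form matches an explicitly replicated form:

```python
import numpy as np

M, K, N = 4, 8, 6
rng = np.random.default_rng(0)
x = rng.standard_normal((M, K)).astype(np.float32)
w = rng.standard_normal((K, N)).astype(np.float32)
residual = rng.standard_normal((M, N)).astype(np.float32)
ternary_a = rng.standard_normal((1, N)).astype(np.float32)  # row-broadcast input
scalar = np.float32(2.0)

# Fused semantics: NumPy broadcasts (1, N) over the (M, N) matmul output.
fused = residual + scalar * (x @ w) * ternary_a

# Same result with ternary_a explicitly replicated row by row.
replicated = np.repeat(ternary_a, M, axis=0)
explicit = residual + scalar * (x @ w) * replicated
```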
Performance
wan2.2_14b-720p-glx: Single GLX
(M, K, N) = (9472, 1280, 5120)
wan2.2_14b-720p Quad GLX
(M, K, N) = (2368, 1280, 5120)
As a reference, here's the execution time of addcmul.
Fusing the operations saves ~50-65% of the execution time of addcmul.
Checklist
Model tests
If your changes cover model-related code, you should run tests corresponding to affected models and platforms (Single card, T3K, Galaxy). "Choose your pipeline" workflows facilitate running multiple kinds of tests in a single run. Each offers `models-mandatory` and `models-extended` presets. The former includes a minimal set of tests, to be run always. The latter extends that with additional ones - use your best judgement in deciding which is the most appropriate for your PR.
- `models-mandatory` preset (runs: Device perf regressions and Frequent model and ttnn tests): Device Perf (same failures as main)
- `models-extended` preset (runs: the mandatory tests, plus Demo and Model perf tests) (main, flux1 validated here)
- `models-mandatory` preset (runs: Unit tests)
- `models-extended` preset (runs: the mandatory tests, plus Demo and Model perf tests)
- `models-mandatory` preset (runs: Quick tests)
- `models-extended` preset (runs: the mandatory tests, plus Demo and Model perf tests)