
Commit 23f2619

Matmul - Add Support for 2D DRAM interleaved in0 + batched height sharded in1 (#37681)
### Ticket
#37403

### Problem description
For Deepseek MLA, prefill must reuse the same weights as decode. Optimised decode will require sharded weights, while prefill has assumed interleaved inputs. To reuse the sharded weights, prefill will need to run with interleaved activations and sharded weights. For the batched matmuls, this combination of inputs was not supported. Based on past experience with Llama, this is expected to slow the matmuls down by about 30%.

### What's changed
This PR adds support for the specific case where in1 of a batched matmul is height sharded, with the sharding cleanly splitting in1 along B/num_banks (see the sketch after the checklist below). The in1 reader is updated to correctly index the data, while in0/compute remain unchanged.

The PR also includes a test that exercises all of the prefill matmuls with this interleaved + sharded pattern, with the sequence length left at the minimum (128) because of the test time required for longer sequence lengths. These matmuls were profiled and compared against interleaved + interleaved matmuls with no program config (i.e. the current prefill implementation) at sequence lengths up to 8k (the larger 32k and 128k sequence lengths are too unwieldy to profile). The relative slowdown worsens as sequence length increases: the best case speedup is 0.47x, the worst case slowdown is 1.16x.

Also note that the wkv_b1 numbers assume 8 DRAM banks. Adding this support is in progress and should be complete before prefill perf becomes a focus. However, if wkv_b1 is instead padded for 12 DRAM banks (as it will be for the very first implementation), it will be about 1.5x slower than currently shown. This does not impact the other matmuls.

<img width="995" height="608" alt="image" src="https://github.com/user-attachments/assets/d41fe9d2-bdc8-4d64-b7d2-898b38c4eb61" />

### Checklist
- [x] [![All post-commit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml/badge.svg?branch=edwinlee/37403/deepseek_prefill_sharded)](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml?query=branch:edwinlee/37403/deepseek_prefill_sharded)
- [x] [![Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml/badge.svg?branch=edwinlee/37403/deepseek_prefill_sharded)](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml?query=branch:edwinlee/37403/deepseek_prefill_sharded)
- [ ] [![cpp-unit-tests](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml/badge.svg?branch=edwinlee/37403/deepseek_prefill_sharded)](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml?query=branch:edwinlee/37403/deepseek_prefill_sharded)
- [ ] New/Existing tests provide coverage for changes

#### Model tests
If your changes cover model-related code, you should run tests corresponding to the affected models and platforms (Single card, T3K, Galaxy). "Choose your pipeline" workflows facilitate running multiple kinds of tests in a single run. Each offers `models-mandatory` and `models-extended` presets. The former includes a minimal set of tests, to be run always. The latter extends that with additional ones - use your best judgement in deciding which is most appropriate for your PR.
- [ ] [![(Single) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml/badge.svg?branch=edwinlee/37403/deepseek_prefill_sharded)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml?query=branch:edwinlee/37403/deepseek_prefill_sharded)
  - [ ] `models-mandatory` preset (runs: [Device perf regressions](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml) and [Frequent model and ttnn tests](https://github.com/tenstorrent/tt-metal/actions/workflows/fast-dispatch-full-regressions-and-models.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/single-card-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml) tests)
  - [ ] other selection - specify runs
- [ ] [![(T3K) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-t3k.yaml/badge.svg?branch=edwinlee/37403/deepseek_prefill_sharded)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-t3k.yaml?query=branch:edwinlee/37403/deepseek_prefill_sharded)
  - [ ] `models-mandatory` preset (runs: [Unit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-unit-tests.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-model-perf-tests.yaml) tests)
  - [ ] other selection - specify runs
- [ ] [![(Galaxy) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml/badge.svg?branch=edwinlee/37403/deepseek_prefill_sharded)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml?query=branch:edwinlee/37403/deepseek_prefill_sharded)
  - [ ] `models-mandatory` preset (runs: [Quick tests](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-quick.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-perf-tests.yaml) tests)
  - [ ] other selection - specify runs
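For orientation, here is a minimal sketch (editor's illustration, not code from this PR) of the batched height-sharded in1 layout described above, using the wkv_b1 shapes from the new test. The bank-to-batch mapping shown is an assumption for illustration; the exact indexing is defined by the updated in1 reader kernel.

```python
# Minimal sketch of the newly supported input combination, using the wkv_b1
# shapes from the test below (batch=16, k=128, n=512, 8 DRAM banks).
# Assumption: batch divides evenly across the banks, so no padding is needed.
batch, k, n = 16, 128, 512
num_dram_banks = 8

# in0 (activations) stays DRAM interleaved; only in1 (weights) is sharded.
batches_per_bank = batch // num_dram_banks  # 2 whole batches per bank

# Height sharding stacks each bank's batches along the row dimension,
# so every bank holds a [batches_per_bank * k, n] slice of in1.
in1_shard_shape = [batches_per_bank * k, n]  # [256, 512]

# With a contiguous assignment (assumed here for illustration), batch b would
# sit in bank b // batches_per_bank at row offset (b % batches_per_bank) * k.
# Indexing into these per-bank shards is what the in1 reader change enables.
```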
1 parent be9fb56 commit 23f2619

File tree

5 files changed: +475, -74 lines changed


tests/ttnn/unit_tests/operations/matmul/test_matmul_deepseek.py

Lines changed: 275 additions & 0 deletions
@@ -756,3 +756,278 @@ def test_matmul_batched_dram_sharded_program_cache(device, batch, m, k, n):
```python
    )

    assert device.num_program_cache_entries() == 1


@pytest.mark.parametrize(
    "test_case",
    [
        # Unbatched matmuls - in0 DRAM interleaved, in1 DRAM WIDTH sharded across 12 banks
        # Uses MatmulMultiCoreReuseMultiCastProgramConfig (2D multicast)
        # qkv_a: K=896, N=2112
        {
            "batch": 1,
            "k": 896,
            "n": 2112,
            "in1_dtype": ttnn.bfloat8_b,
            "expected_pcc": 0.999,
        },
        # wq_b: K=1536, N=3072
        {
            "batch": 1,
            "k": 1536,
            "n": 3072,
            "in1_dtype": ttnn.bfloat8_b,
            "expected_pcc": 0.999,
        },
        # wo: K=16384, N=896
        {
            "batch": 1,
            "k": 16384,
            "n": 896,
            "in1_dtype": ttnn.bfloat8_b,
            "expected_pcc": 0.999,
        },
        # Batched matmuls - in0 DRAM interleaved, in1 DRAM HEIGHT sharded
        # wkv_b1 (8 banks): batch=16, 2 per bank, no padding
        {
            "batch": 16,
            "k": 128,
            "n": 512,
            "in1_dtype": ttnn.bfloat8_b,
            "expected_pcc": 0.9997,
            "num_dram_banks": 8,
        },
        # wkv_b2 (12 banks): batch=128, pad to 132
        {
            "batch": 128,
            "k": 512,
            "n": 128,
            "in1_dtype": ttnn.bfloat8_b,
            "expected_pcc": 0.9997,
            "num_dram_banks": 12,
        },
    ],
    ids=[
        "qkv_a",
        "wq_b",
        "wo",
        "wkv_b1_8banks",
        "wkv_b2_12banks",
    ],
)
@pytest.mark.parametrize("seq_len", [128])  # Longer sequence lengths are 1024, 4096, 8192, 32768, 131072
@skip_for_blackhole("Deepseek tests target Wormhole")
def test_prefill_mm_interleaved_sharded(device, test_case, seq_len):
    """
    Tests the MLA prefill matmuls with in0 DRAM interleaved and in1 DRAM sharded.
    Uses MatmulMultiCoreReuseMultiCastProgramConfig (2D multicast).
    This exercises the prefill when forced to use the decode optimised weight sharding
    i.e. in1 is DRAM sharded - width for unbatched, and height (by batch) for batched matmuls
    """
    torch.manual_seed(0)

    batch = test_case["batch"]
    k = test_case["k"]
    n = test_case["n"]
    in1_dtype = test_case["in1_dtype"]
    expected_pcc = test_case["expected_pcc"]
    tile_w = 32
    tile_h = 32
    num_dram_banks = test_case.get("num_dram_banks", 12)

    device_banks = device.dram_grid_size().x

    if device_banks < num_dram_banks:
        pytest.skip("Device has less DRAM banks than required for test")

    if batch == 1:
        # --- Unbatched: in1 DRAM WIDTH sharded ---
        n_padded = pad_to_dram_banks(n, tile_w, tile_w * num_dram_banks)

        in0_shape = [1, 1, seq_len, k]
        in1_shape = [1, 1, k, n]

        in0 = torch.randn(in0_shape, dtype=torch.bfloat16)
        in1 = torch.randn(in1_shape, dtype=torch.bfloat16)

        in0_t = ttnn.from_torch(
            in0,
            dtype=ttnn.bfloat16,
            layout=ttnn.TILE_LAYOUT,
            device=device,
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
        )

        in1_shard_shape = [k, n_padded // num_dram_banks]
        in1_shard_grid = ttnn.CoreRangeSet(
            {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(num_dram_banks - 1, 0))}
        )
        in1_shard_spec = ttnn.ShardSpec(in1_shard_grid, in1_shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
        in1_memory_config = ttnn.MemoryConfig(
            ttnn.TensorMemoryLayout.WIDTH_SHARDED, ttnn.BufferType.DRAM, in1_shard_spec
        )
        in1_t = ttnn.from_torch(
            in1,
            dtype=in1_dtype,
            layout=ttnn.TILE_LAYOUT,
            device=device,
            memory_config=in1_memory_config,
        )
    else:
        # --- Batched: in1 DRAM HEIGHT sharded by batch (matching test_matmul_batched_dram_sharded) ---
        batch_padded = pad_batch_to_dram_banks(batch, num_dram_banks)
        batches_per_bank = batch_padded // num_dram_banks
        k_padded = pad_to_tile(k, tile_w)
        n_padded = pad_to_tile(n, tile_w)

        in0_orig = torch.randn([1, batch, seq_len, k], dtype=torch.bfloat16)
        in1_orig = torch.randn([1, batch, k, n], dtype=torch.bfloat16)

        in0 = torch.zeros([1, batch_padded, seq_len, k_padded], dtype=torch.bfloat16)
        in0[:, :batch, :seq_len, :k] = in0_orig
        in1 = torch.zeros([1, batch_padded, k_padded, n_padded], dtype=torch.bfloat16)
        in1[:, :batch, :k, :n] = in1_orig

        in0_t = ttnn.from_torch(
            in0,
            dtype=ttnn.bfloat16,
            layout=ttnn.TILE_LAYOUT,
            device=device,
            memory_config=ttnn.DRAM_MEMORY_CONFIG,
        )

        dram_shard_grid = ttnn.CoreRangeSet(
            {ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(num_dram_banks - 1, 0))}
        )
        in1_shard_shape = [batches_per_bank * k_padded, n_padded]
        in1_shard_spec = ttnn.ShardSpec(dram_shard_grid, in1_shard_shape, ttnn.ShardOrientation.ROW_MAJOR)
        in1_memory_config = ttnn.MemoryConfig(
            ttnn.TensorMemoryLayout.HEIGHT_SHARDED, ttnn.BufferType.DRAM, in1_shard_spec
        )
        in1_t = ttnn.from_torch(
            in1,
            dtype=in1_dtype,
            layout=ttnn.TILE_LAYOUT,
            device=device,
            memory_config=in1_memory_config,
        )

    # Compute 2D grid size
    M_tiles = seq_len // tile_h
    K_tiles = k // tile_w
    N_tiles = n // tile_w

    # grid_x splits N dimension; grid_y splits M dimension
    # grid_x only needs to divide N_tiles (not K_tiles)
    grid_x = 1
    for x in range(min(8, N_tiles), 0, -1):
        if N_tiles % x == 0:
            grid_x = x
            break

    grid_y = 1
    for y in range(min(8, M_tiles), 0, -1):
        if M_tiles % y == 0:
            grid_y = y
            break

    grid_size = (grid_x, grid_y)

    in0_block_h = M_tiles // grid_y
    out_block_h = in0_block_h
    out_block_w = N_tiles // grid_x

    # in0_block_w (inner dim block) must divide K_tiles.
    # Target: keep in1 CB tiles (in0_block_w * out_block_w * 2) under ~256 tiles
    # while maximizing in0_block_w to minimize inner loop iterations.
    max_in1_cb_tiles = 256  # ~140 KB for bfloat8_b with double buffering
    in0_block_w = K_tiles
    while in0_block_w > 1:
        if K_tiles % in0_block_w == 0 and in0_block_w * out_block_w * 2 <= max_in1_cb_tiles:
            break
        in0_block_w -= 1
    # Ensure in0 CB is also reasonable
    while in0_block_w > 1 and in0_block_h * in0_block_w * 2 > max_in1_cb_tiles:
        in0_block_w = in0_block_w // 2
        if K_tiles % in0_block_w != 0:
            # Find next valid divisor
            while in0_block_w > 1 and K_tiles % in0_block_w != 0:
                in0_block_w -= 1

    # Determine out_block_h to make sure the matmul will fit in L1.
    # CBs that scale with out_block_h * out_block_w:
    # interm0 (fp32): out_block_h * out_block_w * 4096 bytes
    # output (bf16): out_block_h * out_block_w * 2048 bytes
    # in0 (bf16): out_block_h * in0_block_w * 2 * 2048 bytes (double-buffered)
    # Target: total CB usage < ~1.2 MB (leaving headroom in 1.5 MB L1)
    max_l1_usage = 1200 * 1024  # ~1.2 MB target
    per_core_M = out_block_h
    per_core_N = out_block_w
    actual_out_block_h = out_block_h
    for candidate_h in range(out_block_h, 0, -1):
        if out_block_h % candidate_h != 0:
            continue
        # Estimate CB sizes
        in0_cb = candidate_h * in0_block_w * 2 * 2048  # bf16 double-buffered
        in1_cb = out_block_w * in0_block_w * 2 * 1088  # bf8b double-buffered
        interm0_cb = candidate_h * out_block_w * 4096  # fp32
        output_cb = candidate_h * out_block_w * 2048  # bf16
        total = in0_cb + in1_cb + interm0_cb + output_cb
        if total <= max_l1_usage:
            actual_out_block_h = candidate_h
            break

    # Subblock calculation (h * w <= 4, hardware limit)
    out_subblock_w = 1
    for sw in range(min(out_block_w, 4), 0, -1):
        if out_block_w % sw == 0:
            out_subblock_w = sw
            break
    max_sh = 4 // out_subblock_w
    out_subblock_h = 1
    for sh in range(min(actual_out_block_h, max_sh), 0, -1):
        if actual_out_block_h % sh == 0:
            out_subblock_h = sh
            break

    program_config = ttnn.MatmulMultiCoreReuseMultiCastProgramConfig(
        compute_with_storage_grid_size=grid_size,
        in0_block_w=in0_block_w,
        out_subblock_h=out_subblock_h,
        out_subblock_w=out_subblock_w,
        out_block_h=actual_out_block_h,
        out_block_w=out_block_w,
        per_core_M=per_core_M,
        per_core_N=per_core_N,
        transpose_mcast=False,
        fused_activation=None,
        fuse_batch=(batch == 1),
    )

    compute_kernel_config = ttnn.init_device_compute_kernel_config(
        device.arch(),
        math_fidelity=ttnn.MathFidelity.LoFi,
        math_approx_mode=True,
        fp32_dest_acc_en=True,
        packer_l1_acc=True,
    )

    output_t = ttnn.matmul(
        in0_t,
        in1_t,
        program_config=program_config,
        memory_config=ttnn.DRAM_MEMORY_CONFIG,
        dtype=ttnn.bfloat16,
        compute_kernel_config=compute_kernel_config,
    )

    # Validate
    output_tensor = ttnn.to_torch(output_t)
    if batch == 1:
        pt_out = torch.matmul(in0, in1)
    else:
        output_tensor = output_tensor[:, :batch, :seq_len, :n]
        pt_out = torch.matmul(in0_orig, in1_orig)

    pcc_passed, pcc_message = comp_pcc(pt_out, output_tensor, expected_pcc)
    assert pcc_passed, f"PCC check failed: {pcc_message}"
```
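To make the block-sizing heuristic in the test concrete, here is a hand-worked sketch (editor's illustration, not part of the diff) of the values it resolves to for the `wkv_b2_12banks` case at seq_len=128; the numbers are derived by following the code above and should be treated as an estimate.

```python
# Hand-derived illustration for wkv_b2_12banks: batch=128, k=512, n=128, seq_len=128, 32x32 tiles.
M_tiles, K_tiles, N_tiles = 128 // 32, 512 // 32, 128 // 32  # 4, 16, 4

grid_size = (4, 4)   # largest divisors of N_tiles and M_tiles that are <= 8
out_block_h = M_tiles // 4   # 1
out_block_w = N_tiles // 4   # 1
in0_block_w = 16             # K_tiles itself: 16 * 1 * 2 = 32 in1 CB tiles, under the 256-tile cap
out_subblock_h, out_subblock_w = 1, 1  # 1x1 output blocks leave no room for larger subblocks

# Per-core circular-buffer estimate, using the same byte sizes as the test:
in0_cb = out_block_h * in0_block_w * 2 * 2048    # 65,536 B (bf16, double-buffered)
in1_cb = out_block_w * in0_block_w * 2 * 1088    # 34,816 B (bf8_b, double-buffered)
interm0_cb = out_block_h * out_block_w * 4096    # 4,096 B (fp32)
output_cb = out_block_h * out_block_w * 2048     # 2,048 B (bf16)
total = in0_cb + in1_cb + interm0_cb + output_cb  # 106,496 B (~104 KiB), well under the ~1.2 MB target
```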
