Skip to content

Commit 0879a6a

Browse files
authored
Add initial tuning for M5 pro and max (#3211)
1 parent a9573f9 commit 0879a6a

2 files changed

Lines changed: 16 additions & 0 deletions

File tree

mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222

2323
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
2424
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 256, 2, 2) \
25+
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 128, 64, 2, 4) \
26+
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 128, 256, 2, 4) \
27+
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 64, 4, 4) \
28+
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 256, 4, 4) \
2529
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 128, 128, 512, 4, 4)
2630

2731
instantiate_gemm_shapes_helper(float16, half, float16, half);

mlx/backend/metal/matmul.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,15 @@ void steel_matmul_regular_axpby_nax(
204204
int bm = 128, bn = 128, bk = 512;
205205
int wm = 4, wn = 4;
206206

207+
// Temp routing for larger devices
208+
char devc = d.get_architecture().back();
209+
if (devc == 's' || devc == 'c' || devc == 'd') {
210+
bk = (K >= 8192 && K > (M + N)) ? 64 : 256;
211+
212+
bm = 64;
213+
wm = 2;
214+
}
215+
207216
// Prepare kernel name
208217
std::ostringstream kname;
209218

@@ -268,6 +277,9 @@ void steel_matmul_regular_axpby_nax(
268277

269278
// TODO: Explore device-based tuning for swizzle
270279
int swizzle_log = tm <= 3 ? 0 : 1;
280+
if (devc == 's' || devc == 'c' || devc == 'd') {
281+
swizzle_log = 2;
282+
}
271283

272284
// Prepare steel matmul params
273285
GEMMParams params{/* const int M = */ M,

0 commit comments

Comments
 (0)