Skip to content

Commit 3cde719

Browse files
authored
Route to gather qmm only for many tokens per expert (#2082)
1 parent 5de6d94 commit 3cde719

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

mlx/backend/metal/quantized.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -850,14 +850,14 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
850850
int M = x.shape(-2);
851851
int N = out.shape(-1);
852852
int B = out.size() / M / N;
853+
int E = w.size() / w.shape(-1) / w.shape(-2);
853854
int vector_limit = transpose_ ? get_qmv_batch_limit(K, N, d) : 4;
854855

855856
// We are walking x in order and w is also in order so we can batch up the
856857
// matmuls and reuse reading x and w.
857858
//
858-
// TODO: Tune 16 here a bit better. Maybe also choose it dynamically based
859-
// on B and (w.size() / K / N).
860-
if (M == 1 && B >= 16 && right_sorted_ == true) {
859+
// TODO: Tune the thresholds here a bit better: 16 is the minimum B and 8 is
// the minimum B / E (tokens per expert).
860+
if (M == 1 && B >= 16 && right_sorted_ == true && B / E >= 8) {
861861
gather_qmm_rhs(
862862
x,
863863
w,

0 commit comments

Comments
 (0)