@@ -265,9 +265,15 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
265265 cb (g_diff, " g_diff" , il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
266266
267267 ggml_tensor * g_diff_exp = ggml_exp (ctx0, g_diff);
268- ggml_tensor * key_gdiff = ggml_mul (ctx0, k, g_diff_exp);
268+ ggml_tensor * g_diff_exp_t = ggml_reshape_4d (ctx0, g_diff_exp,
269+ 1 , chunk_size, n_chunks, g_diff_exp->ne [3 ]);
270+
271+ ggml_tensor * key_gdiff = ggml_mul (ctx0, k, g_diff_exp_t );
269272 cb (key_gdiff, " key_gdiff" , il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
270273
274+ ggml_tensor * key_gdiff_t = ggml_cont (ctx0, ggml_transpose (ctx0, key_gdiff));
275+ cb (key_gdiff_t , " key_gdiff_t" , il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
276+
271277
272278 // state to be updated per chunk
273279 ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
@@ -322,9 +328,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
322328 : ggml_concat (ctx0, core_attn_out, core_attn_out_chunk, 2 );
323329
324330 // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
325- ggml_tensor * k_gdiff = ggml_cont (ctx0, get_slice_2d (ctx0, key_gdiff , chunk) );
331+ ggml_tensor * k_gdiff_t = get_slice_2d (ctx0, key_gdiff_t , chunk);
326332 // ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
327- ggml_tensor * kgdmulvnew = ggml_mul_mat (ctx0, v_new_t , ggml_cont (ctx0, ggml_transpose (ctx0, k_gdiff)) );
333+ ggml_tensor * kgdmulvnew = ggml_mul_mat (ctx0, v_new_t , k_gdiff_t );
328334
329335 // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
330336 ggml_tensor * gexp_last_chunk = ggml_cont (ctx0, get_slice_2d (ctx0, g_last_exp, chunk));
0 commit comments