@@ -382,10 +382,10 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
382382 if (m_config.gate_up_type != LLMMLPNode::GATE_UP_TYPE::SEPARATE) {
383383 N = w_gate.size (0 ) / 2 ;
384384 if (m_config.gate_up_type == LLMMLPNode::GATE_UP_TYPE::COMBINED_UP_GATE) {
385- // When VariadicSplit output[1 ] connects to gate instead of up, swap the pointers
385+ // COMBINED_UP_GATE: VariadicSplit output[0] connects to up, output[1] connects to gate
386386 gate_up.setup (w_gate.ptr_v (N, 0 ), w_gate.ptr_v (), w_gate.stride_bytes (0 ), N * 2 , K, config);
387387 } else {
388- // Normal case : VariadicSplit output[1] connects to up
388+ // COMBINED_GATE_UP: VariadicSplit output[0] connects to gate, output[1] connects to up
389389 gate_up.setup (w_gate.ptr_v (), w_gate.ptr_v (N, 0 ), w_gate.stride_bytes (0 ), N * 2 , K, config);
390390 }
391391 } else {
@@ -407,8 +407,7 @@ struct LLMMLP::Executor : public LLMMLP::ExecutorBase {
407407 auto * scale_first = w_scale_gate;
408408 auto * scale_second = w_scale_up;
409409 if (m_config.gate_up_type == LLMMLPNode::GATE_UP_TYPE::COMBINED_UP_GATE) {
410- scale_first = w_scale_up;
411- scale_second = w_scale_gate;
410+ std::swap (scale_first, scale_second);
412411 }
413412 for (size_t i = 0 ; i < N; i += 16 ) {
414413 memcpy (dst, scale_first + i, 16 * sizeof (float ));
0 commit comments