diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
index 6d1d191689466..14bddaf324ae7 100644
--- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
+++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc
@@ -118,13 +118,23 @@ Status ConvertToMlasQ4Format(const uint8_t* quantized_data,
   DequantizeBlockWithMlas(quantized_data, scales, zero_points, block_size, num_bits, rows, cols, temp_float, nullptr);
 
-  size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast<size_t>(cols), static_cast<size_t>(rows));
+  // Transpose from N x K (weights) to K x N.
+  // DirectQ4Gemm expects weights to be packed in a specific layout ([K, N] logically)
+  auto transposed_float_buffer = IAllocator::MakeUniquePtr<float>(allocator, static_cast<size_t>(rows * cols));
+  float* transposed_float = transposed_float_buffer.get();
+  for (int64_t r = 0; r < rows; ++r) {
+    for (int64_t c = 0; c < cols; ++c) {
+      transposed_float[c * rows + r] = temp_float[r * cols + c];
+    }
+  }
+
+  size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast<size_t>(rows), static_cast<size_t>(cols));
   if (packed_size == 0) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration");
   }
 
   mlas_packed_buffer = IAllocator::MakeUniquePtr<uint8_t>(allocator, packed_size);
-  MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast<size_t>(cols), static_cast<size_t>(rows), static_cast<size_t>(cols));
+  MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), transposed_float, static_cast<size_t>(rows), static_cast<size_t>(cols), static_cast<size_t>(rows));
 
   return Status::OK();
 }
@@ -634,6 +644,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
         float* thread_bias2_buffer = thread_bias1_buffer + static_cast<size_t>(fc1_out_features);
         for (int64_t expert_idx : expert_batch) {
+          bool fc2_bias_added_by_mlas = false;
           const auto& routes = expert_token_map[static_cast<size_t>(expert_idx)];
           if (routes.empty()) {
             continue;
           }
@@ -711,8 +722,6 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
           bool use_direct_q4_gemm = (fc1_zp_data == nullptr) &&
                                     CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0,
                                                      fc1_out_features, hidden_size, q_type);
-          bool fc1_used_direct_q4 = false;
-          bool fc1_bias_handled_by_q4_gemm = false;
 
           if (use_direct_q4_gemm) {
             IAllocatorUniquePtr<uint8_t> mlas_packed_fc1;
@@ -750,7 +759,6 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
                                                 num_expert_tokens, fc1_out_features, hidden_size, q_type, tp);
             if (gemm_status.IsOK()) {
-              fc1_used_direct_q4 = true;
               goto fc1_gemm_done;
             }
           }
@@ -797,8 +805,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
                     0.0f, C1, n, tp);
 
-          fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias;
-          if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) {
+          if (has_fc1_bias) {
            const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features;
            if constexpr (std::is_same_v<T, MLFloat16>) {
              MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(B1_bias), thread_bias1_buffer, static_cast<size_t>(fc1_out_features));
@@ -891,7 +898,6 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
          bool use_direct_q4_gemm_fc2 = (fc2_zp_data == nullptr) &&
                                        CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0,
                                                         hidden_size, inter_size, q_type2);
-          bool fc2_used_direct_q4 = false;
 
           if (use_direct_q4_gemm_fc2) {
             IAllocatorUniquePtr<uint8_t> mlas_packed_fc2;
@@ -929,7 +935,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
                                                 num_expert_tokens, hidden_size, inter_size, q_type2, tp);
             if (gemm_status.IsOK()) {
-              fc2_used_direct_q4 = true;
+              fc2_bias_added_by_mlas = true;
               goto fc2_gemm_done;
             }
           }
@@ -979,8 +985,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
 
         fc2_gemm_done:
-          bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias;
-          if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) {
+          if (has_fc2_bias && !fc2_bias_added_by_mlas) {
            const T* B2_bias = fc2_bias_data + expert_idx * hidden_size;
            if constexpr (std::is_same_v<T, MLFloat16>) {
              MlasConvertHalfToFloatBuffer(reinterpret_cast<const MLFloat16*>(B2_bias), thread_bias2_buffer, static_cast<size_t>(hidden_size));
@@ -1015,7 +1020,7 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
            float* dest = thread_local_outputs + static_cast<size_t>(thread_id) * output_buffer_size + buffer_offset;
            const float* src = C2 + i * hidden_size;
 
-            if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) {
+            if (has_fc2_bias && !fc2_bias_added_by_mlas) {
              const size_t unroll_factor = narrow<size_t>(GetUnrollFactor(hidden_size));
              size_t j = 0;
              for (; j + unroll_factor <= narrow<size_t>(hidden_size); j += unroll_factor) {
diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h
index 69f0435615079..d60e5b0164fe8 100644
--- a/onnxruntime/core/mlas/inc/mlas_q4.h
+++ b/onnxruntime/core/mlas/inc/mlas_q4.h
@@ -57,10 +57,10 @@ MlasQ4GemmPackBSize(
 *
 * @param QType      type of block quantization
 * @param PackedBuf  destination buffer
- * @param FpData     the pointer to fp32 matrix
- * @param N          the number of columns of matrix B.
- * @param K          the number of rows of matrix B.
- * @param ldb        leading dimension of B
+ * @param FpData     the pointer to fp32 matrix, with shape [K, N].
+ * @param N          the number of columns of matrix B (Output Channels).
+ * @param K          the number of rows of matrix B (Input Channels).
+ * @param ldb        leading dimension of FpData (usually N)
 */
 void
 MLASCALL
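The transpose added to ConvertToMlasQ4Format and the clarified MlasQ4GemmPackB contract describe the same layout requirement: the dequantized expert weight is stored row-major as [N, K] (rows = output channels, cols = input channels), while the packer consumes B as a row-major [K, N] matrix with ldb = N. A minimal PyTorch sketch of that layout change, for illustration only and with hypothetical sizes (assuming rows plays the role of N and cols of K, as in the updated pack call):

import torch

N, K = 4, 6  # hypothetical output/input channel counts
temp = torch.arange(N * K, dtype=torch.float32).reshape(N, K)  # stands in for temp_float, [N, K] row-major

# transposed_float[c * rows + r] = temp_float[r * cols + c] from the loop above
# is just the matrix transpose, giving a [K, N] buffer with leading dimension N.
transposed = temp.t().contiguous()

assert transposed.flatten()[1 * N + 0] == temp[0, 1]  # element (k=1, n=0) comes from (n=0, k=1)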
diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py
index 90ebb148a26a5..238ac4d1f077d 100644
--- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py
+++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py
@@ -364,7 +364,7 @@ def create_cpu_moe_onnx_graph(
     use_swiglu=False,
     use_quant=False,
     quant_bits=4,
-    swiglu_interleaved=False,
+    swiglu_fusion=0,
     block_size=0,
 ):
     if not has_onnx:
@@ -400,10 +400,10 @@ def create_cpu_moe_onnx_graph(
         "router_probs",  # 1
         "fc1_experts_weights",  # 2
         "fc1_scales",  # 3
-        "",  # 4: fc1_bias
+        "fc1_experts_bias" if fc1_bias is not None else "",  # 4
         "fc2_experts_weights",  # 5
         "fc2_scales",  # 6
-        "",  # 7: fc2_bias
+        "fc2_experts_bias" if fc2_bias is not None else "",  # 7
         "",  # 8: fc3_weights
         "",  # 9: fc3_scales
         "",  # 10: fc3_bias
@@ -442,11 +442,10 @@ def create_cpu_moe_onnx_graph(
             normalize_routing_weights=normalize_routing,
             activation_type=activation,
             # Add new attributes with backwards-compatible default values
-            swiglu_fusion=1 if use_swiglu else 0,  # 1 if using SwiGLU activation
+            swiglu_fusion=swiglu_fusion,
             swiglu_limit=7.0,
             activation_alpha=1.702,
             activation_beta=1.0,
-            swiglu_interleaved=1 if swiglu_interleaved else 0,  # Enable this attribute
             domain="com.microsoft",
         ),
     ]
@@ -559,6 +558,30 @@ def create_cpu_moe_onnx_graph(
             )
         )
 
+    if fc1_bias is not None:
+        fc1_bias_np = fc1_bias.detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype])
+        initializers.append(
+            helper.make_tensor(
+                "fc1_experts_bias",
+                onnx_dtype,
+                list(fc1_bias.shape),
+                fc1_bias_np.flatten().tolist(),
+                raw=False,
+            )
+        )
+
+    if fc2_bias is not None:
+        fc2_bias_np = fc2_bias.detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype])
+        initializers.append(
+            helper.make_tensor(
+                "fc2_experts_bias",
+                onnx_dtype,
+                list(fc2_bias.shape),
+                fc2_bias_np.flatten().tolist(),
+                raw=False,
+            )
+        )
+
     graph_inputs = [
         helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]),
     ]
@@ -626,7 +649,7 @@ def __init__(
         self.num_experts_per_token = num_experts_per_token
 
 
-def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0):
+def swiglu(x: torch.Tensor, alpha: float = 1.702, beta: float = 1.0, limit: float = 7.0):
     dim = x.shape[-1]
     x = x.view(-1, dim // 2, 2)
     x_glu, x_linear = x[..., 0], x[..., 1]
@@ -635,8 +658,8 @@ def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0):
     x_glu = x_glu.clamp(max=limit)
     x_linear = x_linear.clamp(min=-limit, max=limit)
 
-    y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1)
-    return y
+    y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + beta)
+    return y.view(-1, dim // 2)
 
 
 class MoEBlockSparseTop2MLP(nn.Module):
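The reworked swiglu helper assumes the fused, interleaved layout that swiglu_fusion=1 produces: along the last axis, even positions hold gate values and odd positions hold linear values, and the result keeps one activated value per pair. A small usage sketch, for illustration only and with hypothetical shapes:

import torch

tokens, inter = 4, 8  # hypothetical shapes
gate = torch.randn(tokens, inter)    # gate (w1) projection output
linear = torch.randn(tokens, inter)  # linear/value (w3) projection output

# Interleave so column 2*i is gate_i and column 2*i + 1 is linear_i,
# matching x[..., 0] / x[..., 1] inside swiglu() above.
x_interleaved = torch.stack([gate, linear], dim=-1).reshape(tokens, 2 * inter)

y = swiglu(x_interleaved, alpha=1.702, beta=1.0, limit=7.0)
print(y.shape)  # torch.Size([4, 8]): one activated value per gate/linear pair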
@@ -855,7 +878,7 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False
         e = time.time()
         time_ms = (e - s) / repeat * 1000
         is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu
-        is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved
+        is_interleaved = hasattr(self, "swiglu_fusion") and self.swiglu_fusion == 1
         act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU"
         print(f"ORT Performance - {act_type} {self.quant_bits}-bit: {time_ms:.3f} ms/inference")
@@ -868,62 +891,80 @@ def recreate_onnx_model(self):
         """Recreate the ONNX model with the current weights to reflect any changes to the quantization code."""
         w1_list, w2_list = [], []
+        w1_bias_list, w2_bias_list = [], []
         w1_scale_list, w2_scale_list = [], []
         w1_zp_list, w2_zp_list = [], []
         is_4_bit = self.quant_bits == 4
         for i in range(self.num_experts):
-            if self.block_size > 0:
-                # Use block-wise quantization
-                w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant_blockwise(
-                    self.experts[i].w1.weight, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant
-                )
-                w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant_blockwise(
-                    self.experts[i].w2.weight, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant
-                )
+            if hasattr(self.experts[i], "w3"):
+                w1, w3 = self.experts[i].w1.weight, self.experts[i].w3.weight
+                w2 = self.experts[i].w2.weight
+                w1_bias = self.experts[i].w1.bias
+                w3_bias = getattr(self.experts[i].w3, "bias", None)
+
+                # Combine and interleave w1 and w3 for the fused kernel
+                w1_combined = torch.cat([w1, w3], dim=0)  # [2*inter, hidden]
+                if getattr(self, "swiglu_fusion", 0) == 1:
+                    w1_combined = w1_combined.view(2, -1, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim)
+
+                if self.block_size > 0:
+                    w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant_blockwise(
+                        w1_combined, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant
+                    )
+                    w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant_blockwise(
+                        w2, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant
+                    )
+                else:
+                    w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant(
+                        w1_combined, is_4_bit, asymmetric=self.use_asymmetric_quant
+                    )
+                    w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant(
+                        w2, is_4_bit, asymmetric=self.use_asymmetric_quant
+                    )
+
+                if w1_bias is not None and w3_bias is not None:
+                    b1_combined = torch.cat([w1_bias, w3_bias], dim=0)
+                    if getattr(self, "swiglu_fusion", 0) == 1:
+                        b1_combined = b1_combined.view(2, -1).transpose(0, 1).reshape(-1)
+                    w1_bias_list.append(b1_combined.detach().cpu())
+                elif w1_bias is not None:
+                    w1_bias_list.append(w1_bias.detach().cpu())
             else:
-                # Use row-wise quantization
-                w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant(
-                    self.experts[i].w1.weight, is_4_bit, asymmetric=self.use_asymmetric_quant
-                )
-                w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant(
-                    self.experts[i].w2.weight, is_4_bit, asymmetric=self.use_asymmetric_quant
-                )
+                # PhiMoESwiGLUMLP already has interleaved weights in w1
+                w1 = self.experts[i].w1.weight
+                w2 = self.experts[i].w2.weight
+                w1_bias = self.experts[i].w1.bias
 
-            if self.use_swiglu:
-                if self.swiglu_interleaved:
-                    pass
+                if self.block_size > 0:
+                    w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant_blockwise(
+                        w1, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant
+                    )
+                    w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant_blockwise(
+                        w2, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant
+                    )
                 else:
-                    if self.block_size > 0:
-                        w3_scale, pre_qweight3, w3_qdq, w3_zp = quant_dequant_blockwise(
-                            self.experts[i].w3.weight, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant
-                        )
-                    else:
-                        w3_scale, pre_qweight3, w3_qdq, w3_zp = quant_dequant(
-                            self.experts[i].w3.weight, is_4_bit, asymmetric=self.use_asymmetric_quant
-                        )
-
-                    gate_weights = pre_qweight1
-                    value_weights = pre_qweight3
-                    gate_scales = w1_scale
-                    value_scales = w3_scale
-                    gate_zp = w1_zp
-                    value_zp = w3_zp
-
-                    pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0)
-                    w1_scale = torch.cat([gate_scales, value_scales], dim=0)
-                    if w1_zp is not None and w3_zp is not None:
-                        w1_zp = torch.cat([gate_zp, value_zp], dim=0)
-
-                if self.swiglu_interleaved:
-                    self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone())
+                    w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant(
+                        w1, is_4_bit, asymmetric=self.use_asymmetric_quant
+                    )
+                    w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant(
+                        w2, is_4_bit, asymmetric=self.use_asymmetric_quant
+                    )
+                if w1_bias is not None:
+                    w1_bias_list.append(w1_bias.detach().cpu())
+
+            if self.use_swiglu:
+                if getattr(self, "swiglu_fusion", 0) == 1:
+                    self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone())
                 else:
                     intermediate_size = self.experts[i].w1.weight.shape[0]
                     gate_dequant = w1_qdq[:intermediate_size].contiguous().clone()
                     value_dequant = w1_qdq[intermediate_size:].contiguous().clone()
-                    self.experts[i].w1.weight.data = gate_dequant
-                    self.experts[i].w3.weight.data = value_dequant
+                    if hasattr(self.experts[i], "w3"):
+                        self.experts[i].w1.weight.data = gate_dequant
+                        self.experts[i].w3.weight.data = value_dequant
+                    else:
+                        self.experts[i].w1.weight.data = w1_qdq.contiguous().clone()
             else:
                 self.experts[i].w1.weight.data = w1_qdq.contiguous().clone()
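The view/transpose/reshape pattern in the recreate_onnx_model hunk above is what interleaves the gate (w1) and value (w3) projections row by row before quantization, matching the [gate, linear] pairing the swiglu helper expects. A minimal PyTorch sketch of that transform, for illustration only and with hypothetical sizes:

import torch

inter, hidden = 3, 5  # hypothetical sizes
w1 = torch.arange(inter * hidden, dtype=torch.float32).reshape(inter, hidden)        # gate projection
w3 = 100 + torch.arange(inter * hidden, dtype=torch.float32).reshape(inter, hidden)  # value projection

fused = torch.cat([w1, w3], dim=0)  # [2*inter, hidden], blocked: all gate rows, then all value rows
interleaved = fused.view(2, -1, hidden).transpose(0, 1).reshape(-1, hidden)

assert torch.equal(interleaved[0::2], w1)  # even rows: gate
assert torch.equal(interleaved[1::2], w3)  # odd rows: value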
@@ -931,6 +972,9 @@ def recreate_onnx_model(self):
 
             w1_list.append(pre_qweight1)
             w2_list.append(pre_qweight2)
+
+            if self.experts[i].w2.bias is not None:
+                w2_bias_list.append(self.experts[i].w2.bias)
             w1_scale_list.append(w1_scale)
             w2_scale_list.append(w2_scale)
             if w1_zp is not None:
@@ -963,9 +1007,9 @@ def recreate_onnx_model(self):
                 onnx_dtype=self.onnx_dtype,
                 fc1_experts_weights=self.moe_experts_weight1,
                 fc2_experts_weights=self.moe_experts_weight2,
-                # Biases are not used in QMoE
-                fc1_bias=None,
-                fc2_bias=None,
+                # Pass collected biases
+                fc1_bias=torch.stack(w1_bias_list, dim=0) if w1_bias_list else None,
+                fc2_bias=torch.stack(w2_bias_list, dim=0) if w2_bias_list else None,
                 # Scales are used for dequantization
                 fc1_scales=moe_experts_weight_scale1,
                 fc2_scales=moe_experts_weight_scale2,
@@ -975,7 +1019,7 @@ def recreate_onnx_model(self):
                 use_swiglu=self.use_swiglu,
                 use_quant=True,  # Always use QMoE
                 quant_bits=self.quant_bits,
-                swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False,
+                swiglu_fusion=getattr(self, "swiglu_fusion", 0),
                 block_size=self.block_size,  # Add block_size for block-wise quantization
             )
         except Exception:
@@ -1020,7 +1064,7 @@ def parity_check(self):
         max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max()
 
         is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu
-        is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved
+        is_interleaved = getattr(self, "swiglu_fusion", 0) == 1
         act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU"
         quant_type = "Asymmetric" if self.use_asymmetric_quant else "Symmetric"
         block_type = f"Block({self.block_size})" if self.block_size > 0 else "Row"
@@ -1047,24 +1091,6 @@ def parity_check(self):
                 )
                 print("Torch sample:", torch_output.cpu().reshape(-1, hidden_dim)[i, k].item())
                 print("ORT sample:", ort_output.cpu().reshape(-1, hidden_dim)[i, k].item())
-                # Print routing and per-expert contributions for this token from the PyTorch reference
-                try:
-                    hidden_states_flat = hidden_state.view(-1, hidden_dim)
-                    token_vec = hidden_states_flat[i : i + 1]
-                    gate_logits = self.gate(token_vec)
-                    topk_vals, topk_experts = torch.topk(gate_logits, self.top_k, dim=-1)
-                    topk_soft = F.softmax(topk_vals, dim=1)
-                    print("Gate logits:", gate_logits.detach().cpu().numpy())
-                    print("Selected experts:", topk_experts.detach().cpu().numpy())
-                    print("Routing weights:", topk_soft.detach().cpu().numpy())
-                    # Compute per-expert contributions for selected experts
-                    for idx_e, e in enumerate(topk_experts[0].tolist()):
-                        expert_layer = self.experts[e]
-                        expert_out = expert_layer(token_vec)
-                        contrib = expert_out[0, k].item() * topk_soft[0, idx_e].item()
-                        print(f"Expert {e} contrib at hidden {k}: {contrib}")
-                except Exception as _:
-                    pass
 
 ort_dtype_quant_bits_tolerance_map = {
     "FP32:0": (5e-3, 1e-3),
@@ -1128,7 +1154,7 @@ def __init__(
         self.num_experts = config.num_local_experts
         self.top_k = config.num_experts_per_token
         self.use_swiglu = True
-        self.swiglu_interleaved = True
+        self.swiglu_fusion = 1
         self.block_size = block_size
 
         use_quant = self.quant_bits > 0
@@ -1232,7 +1258,7 @@ def __init__(
         self.top_k = config.num_experts_per_tok
         self.router_jitter_noise = config.router_jitter_noise
         self.use_swiglu = True
-        self.swiglu_interleaved = True
+        self.swiglu_fusion = 1
         self.block_size = block_size
 
         use_quant = self.quant_bits > 0
@@ -1314,7 +1340,8 @@ def __init__(
             use_swiglu=self.use_swiglu,
             use_quant=use_quant,
             quant_bits=self.quant_bits,
-            swiglu_interleaved=self.swiglu_interleaved,
+            # swiglu_fusion=1 means fused and interleaved, which is the standard for QMoE.
+            swiglu_fusion=getattr(self, "swiglu_fusion", 0),
             block_size=self.block_size,
         )