diff --git a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template index 1b74862515c69..6a66d2eb402e5 100644 --- a/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/quantization/matmul_nbits.wgsl.template @@ -84,6 +84,21 @@ $MAIN { #if n_bits == 4 var sum = output_element_t(0); var a_offset = idx * (8 / component_a) * component_b; +#if component_b == 1 + let b_value_lower = vec4(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4(zero); + let b_value_upper = vec4(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4(zero); + let b0 = vec4(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b; + let b1 = vec4(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b; +#if component_a == 1 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) + + dot(vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1); +#elif component_a == 2 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b0) + + dot(vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1); +#elif component_a == 4 + sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1); +#endif +#else for (var i = 0u; i < component_b; i++) { let b_value_lower = vec4(unpack4xU8(b_value[i] & 0x0F0F0F0Fu)) - vec4(zero); let b_value_upper = vec4(unpack4xU8((b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4(zero); @@ -102,25 +117,63 @@ $MAIN { a_offset += 2; #endif } +#endif #elif n_bits == 8 var sum = output_element_t(0); var a_offset = idx * (4 / component_a) * component_b; +#if component_b == 1 + let b_value_unpacked = (vec4(unpack4xU8(b_value)) - vec4(zero)) * scale_b; +#if component_a == 1 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value_unpacked); +#elif component_a == 2 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b_value_unpacked); +#elif component_a == 4 + sum += dot(tile_A[a_offset], b_value_unpacked); +#endif +#else for (var i = 0u; i < component_b; i++) { - let b_value = (vec4(unpack4xU8(b_value[i])) - vec4(zero)) * scale_b; + let b_value_unpacked = (vec4(unpack4xU8(b_value[i])) - vec4(zero)) * scale_b; #if component_a == 1 - sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value); + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value_unpacked); a_offset += 4; #elif component_a == 2 - sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b_value); + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b_value_unpacked); a_offset += 2; #elif component_a == 4 - sum += dot(tile_A[a_offset], b_value); + sum += dot(tile_A[a_offset], b_value_unpacked); a_offset += 1; #endif } +#endif #elif n_bits == 2 var sum = output_element_t(0); var a_offset = idx * (16 / component_a) * component_b; +#if component_b == 1 + let b_data_0 = vec4(unpack4xU8(b_value & 0x03030303u)) - vec4(zero); + let b_data_1 = vec4(unpack4xU8((b_value >> 2) & 0x03030303u)) - vec4(zero); + let b_data_2 = vec4(unpack4xU8((b_value >> 4) & 0x03030303u)) - vec4(zero); + let b_data_3 = vec4(unpack4xU8((b_value >> 6) & 0x03030303u)) - vec4(zero); + + let b0 = vec4(b_data_0[0], b_data_1[0], b_data_2[0], b_data_3[0]) * scale_b; + let b1 = vec4(b_data_0[1], b_data_1[1], b_data_2[1], b_data_3[1]) * scale_b; + let b2 = vec4(b_data_0[2], b_data_1[2], b_data_2[2], b_data_3[2]) * scale_b; + let b3 = vec4(b_data_0[3], b_data_1[3], b_data_2[3], b_data_3[3]) * scale_b; + +#if component_a == 1 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) + + dot(vec4(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1) + + dot(vec4(tile_A[a_offset + 8], tile_A[a_offset + 9], tile_A[a_offset + 10], tile_A[a_offset + 11]), b2) + + dot(vec4(tile_A[a_offset + 12], tile_A[a_offset + 13], tile_A[a_offset + 14], tile_A[a_offset + 15]), b3); +#elif component_a == 2 + sum += dot(vec4(tile_A[a_offset], tile_A[a_offset + 1]), b0) + + dot(vec4(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1) + + dot(vec4(tile_A[a_offset + 4], tile_A[a_offset + 5]), b2) + + dot(vec4(tile_A[a_offset + 6], tile_A[a_offset + 7]), b3); +#elif component_a == 4 + sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1) + + dot(tile_A[a_offset + 2], b2) + dot(tile_A[a_offset + 3], b3); +#endif +#else for (var i = 0u; i < component_b; i++) { let b_data_0 = vec4(unpack4xU8(b_value[i] & 0x03030303u)) - vec4(zero); let b_data_1 = vec4(unpack4xU8((b_value[i] >> 2) & 0x03030303u)) - vec4(zero); @@ -150,6 +203,7 @@ $MAIN { a_offset += 4; #endif } +#endif #endif inter_results[local_row_offset + idy][idx] += sum; diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index 04a4c95dd478b..cfdce9479843c 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -197,10 +197,12 @@ void RunTest2Bits(const TestOptions2Bits& opts) { std::vector> execution_providers; if constexpr (std::is_same::value) { - execution_providers.emplace_back(DefaultCpuExecutionProvider()); #ifdef USE_WEBGPU - execution_providers.push_back(DefaultWebGpuExecutionProvider()); + if (!opts.has_zero_point) { + execution_providers.push_back(DefaultWebGpuExecutionProvider()); + } #endif + execution_providers.emplace_back(DefaultCpuExecutionProvider()); test.ConfigEps(std::move(execution_providers)); test.RunWithConfig(); }