Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ $MAIN {
#if n_bits == 4
var sum = output_element_t(0);
var a_offset = idx * (8 / component_a) * component_b;
#if component_b == 1
let b_value_lower = vec4<output_element_t>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
let b_value_upper = vec4<output_element_t>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
let b0 = vec4<output_element_t>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b;
let b1 = vec4<output_element_t>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b;
#if component_a == 1
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) +
dot(vec4<output_element_t>(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1);
#elif component_a == 2
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b0) +
dot(vec4<output_element_t>(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1);
#elif component_a == 4
sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1);
#endif
#else
for (var i = 0u; i < component_b; i++) {
let b_value_lower = vec4<output_element_t>(unpack4xU8(b_value[i] & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
let b_value_upper = vec4<output_element_t>(unpack4xU8((b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
Expand All @@ -102,25 +117,63 @@ $MAIN {
a_offset += 2;
#endif
}
#endif
#elif n_bits == 8
var sum = output_element_t(0);
var a_offset = idx * (4 / component_a) * component_b;
#if component_b == 1
let b_value_unpacked = (vec4<output_element_t>(unpack4xU8(b_value)) - vec4<output_element_t>(zero)) * scale_b;
#if component_a == 1
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value_unpacked);
#elif component_a == 2
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b_value_unpacked);
#elif component_a == 4
sum += dot(tile_A[a_offset], b_value_unpacked);
#endif
#else
for (var i = 0u; i < component_b; i++) {
let b_value = (vec4<output_element_t>(unpack4xU8(b_value[i])) - vec4<output_element_t>(zero)) * scale_b;
let b_value_unpacked = (vec4<output_element_t>(unpack4xU8(b_value[i])) - vec4<output_element_t>(zero)) * scale_b;
#if component_a == 1
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value);
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value_unpacked);
a_offset += 4;
#elif component_a == 2
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b_value);
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b_value_unpacked);
a_offset += 2;
#elif component_a == 4
sum += dot(tile_A[a_offset], b_value);
sum += dot(tile_A[a_offset], b_value_unpacked);
a_offset += 1;
#endif
}
#endif
#elif n_bits == 2
var sum = output_element_t(0);
var a_offset = idx * (16 / component_a) * component_b;
#if component_b == 1
let b_data_0 = vec4<output_element_t>(unpack4xU8(b_value & 0x03030303u)) - vec4<output_element_t>(zero);
let b_data_1 = vec4<output_element_t>(unpack4xU8((b_value >> 2) & 0x03030303u)) - vec4<output_element_t>(zero);
let b_data_2 = vec4<output_element_t>(unpack4xU8((b_value >> 4) & 0x03030303u)) - vec4<output_element_t>(zero);
let b_data_3 = vec4<output_element_t>(unpack4xU8((b_value >> 6) & 0x03030303u)) - vec4<output_element_t>(zero);

let b0 = vec4<output_element_t>(b_data_0[0], b_data_1[0], b_data_2[0], b_data_3[0]) * scale_b;
let b1 = vec4<output_element_t>(b_data_0[1], b_data_1[1], b_data_2[1], b_data_3[1]) * scale_b;
let b2 = vec4<output_element_t>(b_data_0[2], b_data_1[2], b_data_2[2], b_data_3[2]) * scale_b;
let b3 = vec4<output_element_t>(b_data_0[3], b_data_1[3], b_data_2[3], b_data_3[3]) * scale_b;

#if component_a == 1
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) +
dot(vec4<output_element_t>(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1) +
dot(vec4<output_element_t>(tile_A[a_offset + 8], tile_A[a_offset + 9], tile_A[a_offset + 10], tile_A[a_offset + 11]), b2) +
dot(vec4<output_element_t>(tile_A[a_offset + 12], tile_A[a_offset + 13], tile_A[a_offset + 14], tile_A[a_offset + 15]), b3);
#elif component_a == 2
sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b0) +
dot(vec4<output_element_t>(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1) +
dot(vec4<output_element_t>(tile_A[a_offset + 4], tile_A[a_offset + 5]), b2) +
dot(vec4<output_element_t>(tile_A[a_offset + 6], tile_A[a_offset + 7]), b3);
#elif component_a == 4
sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1) +
dot(tile_A[a_offset + 2], b2) + dot(tile_A[a_offset + 3], b3);
#endif
#else
for (var i = 0u; i < component_b; i++) {
let b_data_0 = vec4<output_element_t>(unpack4xU8(b_value[i] & 0x03030303u)) - vec4<output_element_t>(zero);
let b_data_1 = vec4<output_element_t>(unpack4xU8((b_value[i] >> 2) & 0x03030303u)) - vec4<output_element_t>(zero);
Expand Down Expand Up @@ -150,6 +203,7 @@ $MAIN {
a_offset += 4;
#endif
}
#endif
#endif

inter_results[local_row_offset + idy][idx] += sum;
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/test/contrib_ops/matmul_2bits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,12 @@ void RunTest2Bits(const TestOptions2Bits& opts) {

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if constexpr (std::is_same<T1, float>::value) {
execution_providers.emplace_back(DefaultCpuExecutionProvider());
#ifdef USE_WEBGPU
execution_providers.push_back(DefaultWebGpuExecutionProvider());
if (!opts.has_zero_point) {
execution_providers.push_back(DefaultWebGpuExecutionProvider());
}
#endif
execution_providers.emplace_back(DefaultCpuExecutionProvider());
test.ConfigEps(std::move(execution_providers));
test.RunWithConfig();
}
Expand Down
Loading