@@ -84,6 +84,21 @@ $MAIN {
8484#if n_bits == 4
8585 var sum = output_element_t(0);
8686 var a_offset = idx * (8 / component_a) * component_b;
87+ #if component_b == 1
88+ let b_value_lower = vec4<output_element_t>(unpack4xU8(b_value & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
89+ let b_value_upper = vec4<output_element_t>(unpack4xU8((b_value >> 4) & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
90+ let b0 = vec4<output_element_t>(b_value_lower[0], b_value_upper[0], b_value_lower[1], b_value_upper[1]) * scale_b;
91+ let b1 = vec4<output_element_t>(b_value_lower[2], b_value_upper[2], b_value_lower[3], b_value_upper[3]) * scale_b;
92+ #if component_a == 1
93+ sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) +
94+ dot(vec4<output_element_t>(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1);
95+ #elif component_a == 2
96+ sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b0) +
97+ dot(vec4<output_element_t>(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1);
98+ #elif component_a == 4
99+ sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1);
100+ #endif
101+ #else
87102 for (var i = 0u; i < component_b; i++) {
88103 let b_value_lower = vec4<output_element_t>(unpack4xU8(b_value[i] & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
89104 let b_value_upper = vec4<output_element_t>(unpack4xU8((b_value[i] >> 4) & 0x0F0F0F0Fu)) - vec4<output_element_t>(zero);
@@ -102,25 +117,63 @@ $MAIN {
102117 a_offset += 2;
103118#endif
104119 }
120+ #endif
105121#elif n_bits == 8
106122 var sum = output_element_t(0);
107123 var a_offset = idx * (4 / component_a) * component_b;
124+ #if component_b == 1
125+ let b_value_unpacked = (vec4<output_element_t>(unpack4xU8(b_value)) - vec4<output_element_t>(zero)) * scale_b;
126+ #if component_a == 1
127+ sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value_unpacked);
128+ #elif component_a == 2
129+ sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b_value_unpacked);
130+ #elif component_a == 4
131+ sum += dot(tile_A[a_offset], b_value_unpacked);
132+ #endif
133+ #else
108134 for (var i = 0u; i < component_b; i++) {
109- let b_value = (vec4<output_element_t>(unpack4xU8(b_value[i])) - vec4<output_element_t>(zero)) * scale_b;
135+ let b_value_unpacked = (vec4<output_element_t>(unpack4xU8(b_value[i])) - vec4<output_element_t>(zero)) * scale_b;
110136#if component_a == 1
111- sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value );
137+ sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b_value_unpacked );
112138 a_offset += 4;
113139#elif component_a == 2
114- sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b_value );
140+ sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b_value_unpacked );
115141 a_offset += 2;
116142#elif component_a == 4
117- sum += dot(tile_A[a_offset], b_value );
143+ sum += dot(tile_A[a_offset], b_value_unpacked );
118144 a_offset += 1;
119145#endif
120146 }
147+ #endif
121148#elif n_bits == 2
122149 var sum = output_element_t(0);
123150 var a_offset = idx * (16 / component_a) * component_b;
151+ #if component_b == 1
152+ let b_data_0 = vec4<output_element_t>(unpack4xU8(b_value & 0x03030303u)) - vec4<output_element_t>(zero);
153+ let b_data_1 = vec4<output_element_t>(unpack4xU8((b_value >> 2) & 0x03030303u)) - vec4<output_element_t>(zero);
154+ let b_data_2 = vec4<output_element_t>(unpack4xU8((b_value >> 4) & 0x03030303u)) - vec4<output_element_t>(zero);
155+ let b_data_3 = vec4<output_element_t>(unpack4xU8((b_value >> 6) & 0x03030303u)) - vec4<output_element_t>(zero);
156+
157+ let b0 = vec4<output_element_t>(b_data_0[0], b_data_1[0], b_data_2[0], b_data_3[0]) * scale_b;
158+ let b1 = vec4<output_element_t>(b_data_0[1], b_data_1[1], b_data_2[1], b_data_3[1]) * scale_b;
159+ let b2 = vec4<output_element_t>(b_data_0[2], b_data_1[2], b_data_2[2], b_data_3[2]) * scale_b;
160+ let b3 = vec4<output_element_t>(b_data_0[3], b_data_1[3], b_data_2[3], b_data_3[3]) * scale_b;
161+
162+ #if component_a == 1
163+ sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1], tile_A[a_offset + 2], tile_A[a_offset + 3]), b0) +
164+ dot(vec4<output_element_t>(tile_A[a_offset + 4], tile_A[a_offset + 5], tile_A[a_offset + 6], tile_A[a_offset + 7]), b1) +
165+ dot(vec4<output_element_t>(tile_A[a_offset + 8], tile_A[a_offset + 9], tile_A[a_offset + 10], tile_A[a_offset + 11]), b2) +
166+ dot(vec4<output_element_t>(tile_A[a_offset + 12], tile_A[a_offset + 13], tile_A[a_offset + 14], tile_A[a_offset + 15]), b3);
167+ #elif component_a == 2
168+ sum += dot(vec4<output_element_t>(tile_A[a_offset], tile_A[a_offset + 1]), b0) +
169+ dot(vec4<output_element_t>(tile_A[a_offset + 2], tile_A[a_offset + 3]), b1) +
170+ dot(vec4<output_element_t>(tile_A[a_offset + 4], tile_A[a_offset + 5]), b2) +
171+ dot(vec4<output_element_t>(tile_A[a_offset + 6], tile_A[a_offset + 7]), b3);
172+ #elif component_a == 4
173+ sum += dot(tile_A[a_offset], b0) + dot(tile_A[a_offset + 1], b1) +
174+ dot(tile_A[a_offset + 2], b2) + dot(tile_A[a_offset + 3], b3);
175+ #endif
176+ #else
124177 for (var i = 0u; i < component_b; i++) {
125178 let b_data_0 = vec4<output_element_t>(unpack4xU8(b_value[i] & 0x03030303u)) - vec4<output_element_t>(zero);
126179 let b_data_1 = vec4<output_element_t>(unpack4xU8((b_value[i] >> 2) & 0x03030303u)) - vec4<output_element_t>(zero);
@@ -150,6 +203,7 @@ $MAIN {
150203 a_offset += 4;
151204#endif
152205 }
206+ #endif
153207#endif
154208
155209 inter_results[local_row_offset + idy][idx] += sum;
0 commit comments