@@ -49,7 +49,7 @@ void HandleMaybeBiasForMatMul(ShaderHelper& shader,
4949 shader.AdditionalImplementation () << " value = value + output_value_t(" << (is_channels_last ? bias->GetByOffset (" colIn" ) : bias->GetByOffset (" row" )) << " );\n " ;
5050 }
5151 shader.AdditionalImplementation () << " " << activation_snippet << " \n "
52- << output.SetByIndices (" coords" , " value" ) << " \n " ;
52+ << " " << output.SetByIndices (" coords" , " value" ) << " \n " ;
5353}
5454
5555void HandleMatMulWithSplitK (
@@ -127,60 +127,61 @@ void MatMulReadFnSource(ShaderHelper& shader,
127127 const ShaderVariableHelper& b,
128128 const ShaderIndicesHelper* batch_dims,
129129 bool transA,
130- bool transB,
131- bool is_vec4) {
132- int components = is_vec4 ? 4 : 1 ;
130+ bool transB) {
131+ const int a_components = a.NumComponents ();
133132 const std::string data_type = " output_element_t" ;
134- const std::string type_string = MakeScalarOrVectorType (components , data_type);
133+ std::string type_string = MakeScalarOrVectorType (a_components , data_type);
135134
136135 shader.AdditionalImplementation ()
137136 << " fn mm_readA(batch: i32, row: i32, colIn: i32 "
138137 << (batch_dims
139138 ? " , batch_indices: batch_dims_indices_t"
140139 : " " )
141- << " ) -> " << type_string << " {\n "
142- << " var value = " << type_string << " (0);\n "
143- << " let col = colIn * " << components << " ;\n " ;
140+ << " ) -> " << type_string << " {\n "
141+ << " var value = " << type_string << " (0);\n "
142+ << " let col = colIn * " << a_components << " ;\n " ;
144143 if (transA) {
145- shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n " ;
144+ shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_a_outer)) {\n " ;
146145 } else {
147- shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n " ;
146+ shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_inner)) {\n " ;
148147 }
149- shader.AdditionalImplementation () << " var a_indices: a_indices_t;\n " ;
148+ shader.AdditionalImplementation () << " var a_indices: a_indices_t;\n " ;
150149
151150 if (batch_dims) {
152- shader.AdditionalImplementation () << ConvertOutputBatchIndicesToInputBatchIndices (" a" , a, a.Rank () - 2 , batch_dims ? batch_dims->Rank () : 0 , " batch_indices " ) << " \n " ;
151+ shader.AdditionalImplementation () << ConvertOutputBatchIndicesToInputBatchIndices (" a" , a, a.Rank () - 2 , batch_dims ? batch_dims->Rank () : 0 , " batch_indices " );
153152 }
154- shader.AdditionalImplementation () << a.IndicesSet (" a_indices" , a.Rank () - 2 , " u32(row)" ) << " \n "
155- << a.IndicesSet (" a_indices" , a.Rank () - 1 , " u32(colIn)" ) << " \n "
156- << " value = " << a.GetByIndices (" a_indices" ) << " ;\n "
157- << " }\n "
158- << " return value;\n "
153+ shader.AdditionalImplementation () << " " << a.IndicesSet (" a_indices" , a.Rank () - 2 , " u32(row)" ) << " \n "
154+ << " " << a.IndicesSet (" a_indices" , a.Rank () - 1 , " u32(colIn)" ) << " \n "
155+ << " value = " << a.GetByIndices (" a_indices" ) << " ;\n "
156+ << " }\n "
157+ << " return value;\n "
159158 << " }\n\n " ;
160159
161160 // Add the mm_readB function
161+ const int b_components = b.NumComponents ();
162+ type_string = MakeScalarOrVectorType (b_components, data_type);
162163 shader.AdditionalImplementation ()
163164 << " fn mm_readB(batch: i32, row: i32, colIn: i32 "
164165 << (batch_dims
165166 ? " , batch_indices: batch_dims_indices_t"
166167 : " " )
167- << " ) -> " << type_string << " {\n "
168- << " var value = " << type_string << " (0);\n "
169- << " let col = colIn * " << components << " ;\n " ;
168+ << " ) -> " << type_string << " {\n "
169+ << " var value = " << type_string << " (0);\n "
170+ << " let col = colIn * " << b_components << " ;\n " ;
170171
171172 if (transB) {
172- shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n " ;
173+ shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_b_outer) && col < i32(uniforms.dim_inner)) {\n " ;
173174 } else {
174- shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n " ;
175+ shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_inner) && col < i32(uniforms.dim_b_outer)) {\n " ;
175176 }
176177
177- shader.AdditionalImplementation () << " var b_indices: b_indices_t;\n "
178+ shader.AdditionalImplementation () << " var b_indices: b_indices_t;\n "
178179 << ConvertOutputBatchIndicesToInputBatchIndices (" b" , b, b.Rank () - 2 , batch_dims ? batch_dims->Rank () : 0 , " batch_indices" )
179- << b.IndicesSet (" b_indices" , b.Rank () - 2 , " u32(row)" ) << " \n "
180- << b.IndicesSet (" b_indices" , b.Rank () - 1 , " u32(colIn)" ) << " \n "
181- << " value = " << b.GetByIndices (" b_indices" ) << " ;\n "
182- << " }\n "
183- << " return value;\n "
180+ << " " << b.IndicesSet (" b_indices" , b.Rank () - 2 , " u32(row)" ) << " \n "
181+ << " " << b.IndicesSet (" b_indices" , b.Rank () - 1 , " u32(colIn)" ) << " \n "
182+ << " value = " << b.GetByIndices (" b_indices" ) << " ;\n "
183+ << " }\n "
184+ << " return value;\n "
184185 << " }\n\n " ;
185186}
186187
@@ -189,19 +190,19 @@ void MatMulWriteFnSource(ShaderHelper& shader,
189190 const ShaderVariableHelper* bias,
190191 bool is_gemm,
191192 int c_components,
192- int output_components,
193193 bool c_is_scalar,
194194 std::string activation_snippet,
195195 bool is_channels_last,
196196 bool use_split_k,
197197 ProgramVariableDataType output_variable_type) {
198+ const int output_components = output.NumComponents ();
198199 shader.AdditionalImplementation ()
199- << " fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) { \n " ;
200+ << " fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: output_value_t) {\n " ;
200201
201202 shader.AdditionalImplementation () << " let col = colIn * " << output_components << " ;\n " ;
202203
203- shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) { \n "
204- << " var value = valueIn; \n " ;
204+ shader.AdditionalImplementation () << " if(row < i32(uniforms.dim_a_outer) && col < i32(uniforms.dim_b_outer)) {\n "
205+ << " var value = valueIn;\n " ;
205206
206207 if (use_split_k) {
207208 // Set output when MatMul is performed with Split-K.
0 commit comments