@@ -83,6 +83,7 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
8383
8484Status ApplyDP4AMatrixMatMulNBits (const Tensor* a, const Tensor* b, const Tensor* scales,
8585 const Tensor* zero_points, const Tensor* bias,
86+ uint32_t batch_count,
8687 uint32_t M,
8788 uint32_t N,
8889 uint32_t K,
@@ -101,15 +102,15 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
101102 DP4AMatMulQuantizeProgram quantize_program;
102103 quantize_program.SetWorkgroupSize (64 );
103104 uint32_t tile_size = 64 * kVec4Components ;
104- quantize_program.SetDispatchGroupSize ((M * K + tile_size - 1 ) / tile_size, 1 , 1 );
105- TensorShape a_quant_shape{1 , M, K / kU32Components };
105+ quantize_program.SetDispatchGroupSize ((batch_count * M * K + tile_size - 1 ) / tile_size, 1 , 1 );
106+ TensorShape a_quant_shape{batch_count , M, K / kU32Components };
106107 Tensor a_quant = context.CreateGPUTensor (DataTypeImpl::GetType<uint32_t >(), a_quant_shape);
107- TensorShapeVector a_scales_dims ({1 , 1 , M, K / kBlockSizeA });
108+ TensorShapeVector a_scales_dims ({batch_count , 1 , M, K / kBlockSizeA });
108109 Tensor a_scale = context.CreateGPUTensor (a->DataType (), a_scales_dims);
109110 quantize_program.AddInputs ({{a, ProgramTensorMetadataDependency::TypeAndRank, static_cast <int >(kVec4Components )}})
110111 .AddOutputs ({{&a_quant, ProgramTensorMetadataDependency::Rank, a_quant.Shape (), 1 },
111112 {&a_scale, ProgramTensorMetadataDependency::Rank, 1 }})
112- .AddUniformVariable ({M * K / kU32Components });
113+ .AddUniformVariable ({batch_count * M * K / kU32Components });
113114 ORT_RETURN_IF_ERROR (context.RunProgram (quantize_program));
114115
115116 const bool has_zero_points = zero_points != nullptr ;
@@ -128,12 +129,12 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
128129 DP4AMatMulNBitsSmallMProgram mul_program{tile_size_k_vec, tile_size_n, nbits, has_zero_points, has_bias, has_weight_idx, single_scale_weights};
129130 uint32_t num_N_tile = (N + tile_size_n - 1 ) / tile_size_n;
130131 mul_program.SetWorkgroupSize (128 );
131- mul_program.SetDispatchGroupSize (M * num_N_tile);
132+ mul_program.SetDispatchGroupSize (batch_count * M * num_N_tile);
132133 mul_program.AddInputs ({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast <int >(kVec4Components )},
133134 {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1 },
134135 {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast <int >(b_components * kU32Components )},
135136 {scales, ProgramTensorMetadataDependency::TypeAndRank, 1 }})
136- .AddUniformVariables ({M, N, K, K / 16 , K / 32 , block_size, num_N_tile, zero_blocks_per_col, weight_index})
137+ .AddUniformVariables ({batch_count, M, N, K, K / 16 , K / 32 , block_size, num_N_tile, zero_blocks_per_col, weight_index})
137138 .AddOutput ({y, ProgramTensorMetadataDependency::TypeAndRank, 1 })
138139 .CacheHint (nbits, tile_size_k_vec, tile_size_n, has_zero_points, single_scale_weights, has_bias, has_weight_idx);
139140 if (has_zero_points) {
@@ -146,22 +147,24 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
146147 }
147148
148149 constexpr uint32_t kTileSize = 64 ;
149- TensorShape reshaped_y_shape{1 , M, N / kVec4Components };
150+ TensorShape reshaped_y_shape{batch_count , M, N / kVec4Components };
150151 uint32_t num_M_tile = (M + kTileSize - 1 ) / kTileSize ;
151152 uint32_t num_N_tile = (N + kTileSize - 1 ) / kTileSize ;
152153 bool is_qualcomm = context.AdapterInfo ().vendor == std::string_view{" qualcomm" };
153154 DP4AMatMulNBitsProgram mul_program{block_size, nbits, has_zero_points, has_bias, has_weight_idx, is_qualcomm};
154155 mul_program.SetWorkgroupSize (256 );
155- mul_program.SetDispatchGroupSize (num_M_tile * num_N_tile);
156+ mul_program.SetDispatchGroupSize (batch_count * num_M_tile * num_N_tile);
156157 mul_program.AddInputs ({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast <int >(kVec4Components )},
157158 {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1 },
158159 {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast <int >((nbits / 2 ) * kU32Components )},
159160 {scales, ProgramTensorMetadataDependency::TypeAndRank, 1 }})
160- .AddUniformVariables ({{static_cast <uint32_t >(M)},
161+ .AddUniformVariables ({{static_cast <uint32_t >(batch_count)},
162+ {static_cast <uint32_t >(M)},
161163 {static_cast <uint32_t >(N)},
162164 {static_cast <uint32_t >(K)},
163165 {static_cast <uint32_t >(K / 8 )},
164166 {static_cast <uint32_t >(K / 16 )},
167+ {num_M_tile},
165168 {num_N_tile},
166169 {zero_blocks_per_col},
167170 {weight_index}})
@@ -179,7 +182,6 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
179182bool CanApplyDP4AMatrixMatMulNBits (onnxruntime::webgpu::ComputeContext& context,
180183 uint64_t accuracy_level,
181184 uint32_t block_size,
182- uint32_t batch_count,
183185 uint32_t N,
184186 uint32_t K,
185187 uint32_t components_k) {
@@ -189,7 +191,7 @@ bool CanApplyDP4AMatrixMatMulNBits(onnxruntime::webgpu::ComputeContext& context,
189191 bool use_dp4a = context.HasFeature (wgpu::FeatureName::Subgroups) &&
190192 context.AdapterInfo ().vendor != std::string_view{" apple" };
191193 return (accuracy_level == 4 && block_size % 32 == 0 &&
192- batch_count == 1 && components_k == 4 && K % 128 == 0 && N % 16 == 0 &&
194+ components_k == 4 && K % 128 == 0 && N % 16 == 0 &&
193195 use_dp4a);
194196}
195197
0 commit comments