You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[webgpu] make DP4AMatMulNBitsSmallMProgram shader template (#25025)
### Description
This commit refactors the `DP4AMatMulNBitsSmallMProgram` to allow both
`tile_size_k_vec` and `tile_size` to be configured. This change allows
more flexibility for performance tuning without altering the core shader
functionality.
There is no functional change in this commit.
### Motivation and Context
This is a preparatory change to enable `DP4AMatMulNBitsSmallMProgram`
performance optimization work in subsequent commits.
---------
Co-authored-by: Jiajia Qin <jiajiaqin@microsoft.com>
ORT_ENFORCE(WorkgroupSizeX() % tile_size_k_vec_ == 0 && tile_size_k_vec_ % 4 == 0, "tile_size_k_vec_ must evenly divide workgroup size X and be divisible by 4");
ORT_ENFORCE(tile_size_ % sub_tile_count == 0, "tile_size_ must be divisible by sub_tile_count");
332
+
328
333
// This algorithm works to compute dot product of k parallelly, by processing k at each step amongst tile_size_k_vec threads,
329
334
// and utilizing the remaining threads in the workgroup to process additional rows of b in parallel (such that the values in shared memory for A can be reused).
330
335
// For each load of k, the tile_size_k_vec threads also reload B tile_size/num_concurrent_b_rows times to compute partial dot products of other B rows
331
336
// in order to complete all tile_size b rows in this workgroup and also reusing the loaded in register values of a.
332
337
333
-
// 1. Each workgroup handles tile_size_k_vec (16) * k_vectorization_in_b (32) columns (total 512) and num_concurrent_b_rows of matrix B at a time,
338
+
// 1. Each workgroup handles tile_size_k_vec * k_vectorization_in_b (32) columns and num_concurrent_b_rows of matrix B at a time,
334
339
// iterating over the columns to compute a partial dot product.
335
340
// 2. Uses vec4 vectorization where each K represents 32 elements of matrix B
336
-
constexpruint32_t tile_size_k_vec = 16;
337
341
338
342
// 1. Workgroup Responsibility:
339
343
// - Processes one row of matrix A
@@ -346,18 +350,19 @@ Status DP4AMatMulNBitsSmallMProgram::GenerateShaderCode(ShaderHelper& shader) co
346
350
// - Iterates through columns accumulating results in inter_results
347
351
// - Performs final reduction sum in inter_results for output
0 commit comments