@@ -911,8 +911,8 @@ struct vk_device_struct {
911911 vk_pipeline pipeline_pool2d_f32;
912912 vk_pipeline pipeline_rwkv_wkv6_f32;
913913 vk_pipeline pipeline_rwkv_wkv7_f32;
914- // [size_idx][kda] where size_idx: 0=d32 , 1=d64 , 2=d128
915- vk_pipeline pipeline_gated_delta_net[3 ][2];
914+ // [size_idx][kda] where size_idx: 0=d16 , 1=d32 , 2=d64, 3 =d128
915+ vk_pipeline pipeline_gated_delta_net[4 ][2];
916916 vk_pipeline pipeline_ssm_scan_f32_d128;
917917 vk_pipeline pipeline_ssm_scan_f32_d256;
918918 vk_pipeline pipeline_ssm_conv_f32;
@@ -5231,14 +5231,14 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
52315231 ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
52325232
52335233 {
5234- const uint32_t gdn_sizes[] = {32, 64, 128};
5234+ const uint32_t gdn_sizes[] = {16, 32, 64, 128};
52355235 const char * gdn_names[][2] = {
5236+ {"gated_delta_net_f32_d16", "gated_delta_net_f32_d16_kda"},
52365237 {"gated_delta_net_f32_d32", "gated_delta_net_f32_d32_kda"},
52375238 {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
52385239 {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
52395240 };
5240- const bool use_subgroup_reduce = device->subgroup_arithmetic;
5241- for (uint32_t si = 0; si < 3; si++) {
5241+ for (uint32_t si = 0; si < 4; si++) {
52425242 const uint32_t S_V = gdn_sizes[si];
52435243 GGML_ASSERT(is_pow2(S_V));
52445244
@@ -5252,10 +5252,29 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
52525252 lanes_per_column = std::min(S_V, device->subgroup_size);
52535253 }
52545254
5255- const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
5255+ // gated_delta_net.comp relies on S_V % COLS_PER_WG == 0 and
5256+ // S_V % LANES_PER_COLUMN == 0 to avoid bounds checks.
5257+ while (lanes_per_column > 1u) {
5258+ const bool valid_lanes = (device->subgroup_size % lanes_per_column) == 0 &&
5259+ (S_V % lanes_per_column) == 0;
5260+ const uint32_t cols_per_wg = valid_lanes ? device->subgroup_size / lanes_per_column : 0;
5261+ if (valid_lanes && cols_per_wg > 0 && (S_V % cols_per_wg) == 0) {
5262+ break;
5263+ }
5264+ lanes_per_column >>= 1u;
5265+ }
5266+
5267+ GGML_ASSERT((device->subgroup_size % lanes_per_column) == 0);
5268+ GGML_ASSERT((S_V % lanes_per_column) == 0);
5269+ GGML_ASSERT((S_V % (device->subgroup_size / lanes_per_column)) == 0);
5270+
5271+ const bool need_partial_subgroup_reduce = lanes_per_column != 1u && lanes_per_column < device->subgroup_size;
5272+ const bool use_clustered_reduce = device->subgroup_arithmetic && device->subgroup_clustered && need_partial_subgroup_reduce;
5273+ const bool use_subgroup_reduce = device->subgroup_arithmetic && !need_partial_subgroup_reduce;
5274+ const bool use_subgroup_ops = use_clustered_reduce || use_subgroup_reduce;
52565275 size_t gdn_len;
52575276 const void * gdn_data;
5258- if (use_subgroup_reduce && need_clustered_shader ) {
5277+ if (use_clustered_reduce ) {
52595278 gdn_len = gated_delta_net_f32_len;
52605279 gdn_data = (const void *)gated_delta_net_f32_data;
52615280 } else if (use_subgroup_reduce) {
@@ -5272,7 +5291,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
52725291 for (uint32_t kda = 0; kda < 2; kda++) {
52735292 ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
52745293 gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
5275- wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce , device->subgroup_size);
5294+ wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_ops , device->subgroup_size);
52765295 }
52775296 }
52785297 }
@@ -10746,9 +10765,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1074610765 const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0;
1074710766 uint32_t si;
1074810767 switch (S_v) {
10749- case 32: si = 0; break;
10750- case 64: si = 1; break;
10751- case 128: si = 2; break;
10768+ case 16: si = 0; break;
10769+ case 32: si = 1; break;
10770+ case 64: si = 2; break;
10771+ case 128: si = 3; break;
1075210772 default: return nullptr;
1075310773 }
1075410774 return ctx->device->pipeline_gated_delta_net[si][kda];
@@ -17193,7 +17213,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1719317213 case GGML_OP_GATED_DELTA_NET:
1719417214 {
1719517215 const uint32_t S_v = op->src[2]->ne[0];
17196- if (S_v != 32 && S_v != 64 && S_v != 128) {
17216+ if (S_v != 16 && S_v != 32 && S_v != 64 && S_v != 128) {
1719717217 return false;
1719817218 }
1719917219 for (int i = 0; i < 6; i++) {
0 commit comments