Skip to content

Commit d5fb104

Browse files
authored
vulkan: Support gated_delta_net with S_v=16 (#24581)
1 parent 635b65a commit d5fb104

1 file changed

Lines changed: 32 additions & 12 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -911,8 +911,8 @@ struct vk_device_struct {
911911
vk_pipeline pipeline_pool2d_f32;
912912
vk_pipeline pipeline_rwkv_wkv6_f32;
913913
vk_pipeline pipeline_rwkv_wkv7_f32;
914-
// [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
915-
vk_pipeline pipeline_gated_delta_net[3][2];
914+
// [size_idx][kda] where size_idx: 0=d16, 1=d32, 2=d64, 3=d128
915+
vk_pipeline pipeline_gated_delta_net[4][2];
916916
vk_pipeline pipeline_ssm_scan_f32_d128;
917917
vk_pipeline pipeline_ssm_scan_f32_d256;
918918
vk_pipeline pipeline_ssm_conv_f32;
@@ -5231,14 +5231,14 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
52315231
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
52325232

52335233
{
5234-
const uint32_t gdn_sizes[] = {32, 64, 128};
5234+
const uint32_t gdn_sizes[] = {16, 32, 64, 128};
52355235
const char * gdn_names[][2] = {
5236+
{"gated_delta_net_f32_d16", "gated_delta_net_f32_d16_kda"},
52365237
{"gated_delta_net_f32_d32", "gated_delta_net_f32_d32_kda"},
52375238
{"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
52385239
{"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
52395240
};
5240-
const bool use_subgroup_reduce = device->subgroup_arithmetic;
5241-
for (uint32_t si = 0; si < 3; si++) {
5241+
for (uint32_t si = 0; si < 4; si++) {
52425242
const uint32_t S_V = gdn_sizes[si];
52435243
GGML_ASSERT(is_pow2(S_V));
52445244

@@ -5252,10 +5252,29 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
52525252
lanes_per_column = std::min(S_V, device->subgroup_size);
52535253
}
52545254

5255-
const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
5255+
// gated_delta_net.comp relies on S_V % COLS_PER_WG == 0 and
5256+
// S_V % LANES_PER_COLUMN == 0 to avoid bounds checks.
5257+
while (lanes_per_column > 1u) {
5258+
const bool valid_lanes = (device->subgroup_size % lanes_per_column) == 0 &&
5259+
(S_V % lanes_per_column) == 0;
5260+
const uint32_t cols_per_wg = valid_lanes ? device->subgroup_size / lanes_per_column : 0;
5261+
if (valid_lanes && cols_per_wg > 0 && (S_V % cols_per_wg) == 0) {
5262+
break;
5263+
}
5264+
lanes_per_column >>= 1u;
5265+
}
5266+
5267+
GGML_ASSERT((device->subgroup_size % lanes_per_column) == 0);
5268+
GGML_ASSERT((S_V % lanes_per_column) == 0);
5269+
GGML_ASSERT((S_V % (device->subgroup_size / lanes_per_column)) == 0);
5270+
5271+
const bool need_partial_subgroup_reduce = lanes_per_column != 1u && lanes_per_column < device->subgroup_size;
5272+
const bool use_clustered_reduce = device->subgroup_arithmetic && device->subgroup_clustered && need_partial_subgroup_reduce;
5273+
const bool use_subgroup_reduce = device->subgroup_arithmetic && !need_partial_subgroup_reduce;
5274+
const bool use_subgroup_ops = use_clustered_reduce || use_subgroup_reduce;
52565275
size_t gdn_len;
52575276
const void * gdn_data;
5258-
if (use_subgroup_reduce && need_clustered_shader) {
5277+
if (use_clustered_reduce) {
52595278
gdn_len = gated_delta_net_f32_len;
52605279
gdn_data = (const void *)gated_delta_net_f32_data;
52615280
} else if (use_subgroup_reduce) {
@@ -5272,7 +5291,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
52725291
for (uint32_t kda = 0; kda < 2; kda++) {
52735292
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
52745293
gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
5275-
wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size);
5294+
wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_ops, device->subgroup_size);
52765295
}
52775296
}
52785297
}
@@ -10746,9 +10765,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1074610765
const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0;
1074710766
uint32_t si;
1074810767
switch (S_v) {
10749-
case 32: si = 0; break;
10750-
case 64: si = 1; break;
10751-
case 128: si = 2; break;
10768+
case 16: si = 0; break;
10769+
case 32: si = 1; break;
10770+
case 64: si = 2; break;
10771+
case 128: si = 3; break;
1075210772
default: return nullptr;
1075310773
}
1075410774
return ctx->device->pipeline_gated_delta_net[si][kda];
@@ -17193,7 +17213,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1719317213
case GGML_OP_GATED_DELTA_NET:
1719417214
{
1719517215
const uint32_t S_v = op->src[2]->ne[0];
17196-
if (S_v != 32 && S_v != 64 && S_v != 128) {
17216+
if (S_v != 16 && S_v != 32 && S_v != 64 && S_v != 128) {
1719717217
return false;
1719817218
}
1719917219
for (int i = 0; i < 6; i++) {

0 commit comments

Comments
 (0)