Skip to content

Commit ad39cca

Browse files
vulkan: add col2im_1d op (#24425)
* vulkan: add GGML_OP_COL2IM_1D, follow-up to the CPU op * vulkan: col2im_1d bounded gather loop instead of full-K scan with modulo * vulkan: col2im_1d address review from @jeffbolznv * vulkan: col2im_1d return nullptr for unsupported types, address review from @0cc4m
1 parent 7dad2f1 commit ad39cca

3 files changed

Lines changed: 133 additions & 0 deletions

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,9 @@ struct vk_device_struct {
902902
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
903903
vk_pipeline pipeline_timestep_embedding_f32;
904904
vk_pipeline pipeline_conv_transpose_1d_f32;
905+
vk_pipeline pipeline_col2im_1d_f32;
906+
vk_pipeline pipeline_col2im_1d_f16;
907+
vk_pipeline pipeline_col2im_1d_bf16;
905908
vk_pipeline pipeline_snake_f32;
906909
vk_pipeline pipeline_snake_f16;
907910
vk_pipeline pipeline_snake_bf16;
@@ -1552,6 +1555,16 @@ struct vk_op_timestep_embedding_push_constants {
15521555
uint32_t max_period;
15531556
};
15541557

1558+
struct vk_op_col2im_1d_push_constants {
1559+
uint32_t T_out;
1560+
uint32_t OC;
1561+
uint32_t K_OC;
1562+
uint32_t T_in;
1563+
uint32_t K;
1564+
int32_t stride;
1565+
int32_t p0;
1566+
};
1567+
15551568
struct vk_op_conv_transpose_1d_push_constants {
15561569
uint32_t Cout;
15571570
uint32_t Cin;
@@ -5203,6 +5216,9 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
52035216
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
52045217

52055218
ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
5219+
ggml_vk_create_pipeline(device, device->pipeline_col2im_1d_f32, "col2im_1d_f32", col2im_1d_f32_len, col2im_1d_f32_data, "main", 2, sizeof(vk_op_col2im_1d_push_constants), {256, 1, 1}, {}, 1, true);
5220+
ggml_vk_create_pipeline(device, device->pipeline_col2im_1d_f16, "col2im_1d_f16", col2im_1d_f16_len, col2im_1d_f16_data, "main", 2, sizeof(vk_op_col2im_1d_push_constants), {256, 1, 1}, {}, 1, true);
5221+
ggml_vk_create_pipeline(device, device->pipeline_col2im_1d_bf16, "col2im_1d_bf16", col2im_1d_bf16_len, col2im_1d_bf16_data, "main", 2, sizeof(vk_op_col2im_1d_push_constants), {256, 1, 1}, {}, 1, true);
52065222

52075223
ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
52085224
ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
@@ -10702,6 +10718,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1070210718
return ctx->device->pipeline_conv_transpose_1d_f32;
1070310719
}
1070410720
return nullptr;
10721+
case GGML_OP_COL2IM_1D:
10722+
switch (src0->type) {
10723+
case GGML_TYPE_F32: return ctx->device->pipeline_col2im_1d_f32;
10724+
case GGML_TYPE_F16: return ctx->device->pipeline_col2im_1d_f16;
10725+
case GGML_TYPE_BF16: return ctx->device->pipeline_col2im_1d_bf16;
10726+
default: return nullptr;
10727+
}
1070510728
case GGML_OP_POOL_2D:
1070610729
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
1070710730
return ctx->device->pipeline_pool2d_f32;
@@ -11147,6 +11170,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
1114711170
{
1114811171
elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
1114911172
} break;
11173+
case GGML_OP_COL2IM_1D:
11174+
{
11175+
elements = { uint32_t(dst->ne[0]), uint32_t(dst->ne[1]), 1 };
11176+
} break;
1115011177
case GGML_OP_POOL_2D:
1115111178
{
1115211179
const uint32_t N = dst->ne[3];
@@ -12936,6 +12963,32 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context&
1293612963
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p));
1293712964
}
1293812965

12966+
static void ggml_vk_col2im_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
12967+
// src0: [K_OC, T_in] columns from matmul
12968+
// dst: [T_out, OC]
12969+
12970+
const int32_t stride = dst->op_params[0];
12971+
const int32_t oc = dst->op_params[1];
12972+
const int32_t p0 = dst->op_params[2];
12973+
12974+
const uint32_t K_OC = static_cast<uint32_t>(src0->ne[0]);
12975+
const uint32_t T_in = static_cast<uint32_t>(src0->ne[1]);
12976+
const uint32_t T_out = static_cast<uint32_t>(dst->ne[0]);
12977+
const uint32_t OC = static_cast<uint32_t>(oc);
12978+
const uint32_t K = K_OC / OC;
12979+
12980+
vk_op_col2im_1d_push_constants p{};
12981+
p.T_out = T_out;
12982+
p.OC = OC;
12983+
p.K_OC = K_OC;
12984+
p.T_in = T_in;
12985+
p.K = K;
12986+
p.stride = stride;
12987+
p.p0 = p0;
12988+
12989+
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COL2IM_1D, std::move(p));
12990+
}
12991+
1293912992
// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b.
1294012993
// Match the naive mul -> sin -> sqr -> mul -> add chain and run the
1294112994
// dedicated kernel directly. The pattern is validated by
@@ -14423,6 +14476,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1442314476
case GGML_OP_TIMESTEP_EMBEDDING:
1442414477
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);
1442514478

14479+
break;
14480+
case GGML_OP_COL2IM_1D:
14481+
ggml_vk_col2im_1d(ctx, compute_ctx, src0, node);
14482+
1442614483
break;
1442714484
case GGML_OP_CONV_TRANSPOSE_1D:
1442814485
ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node);
@@ -17188,6 +17245,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1718817245
return op->src[0]->type == GGML_TYPE_F32;
1718917246
case GGML_OP_CONV_TRANSPOSE_1D:
1719017247
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
17248+
case GGML_OP_COL2IM_1D:
17249+
return (op->src[0]->type == GGML_TYPE_F32 ||
17250+
op->src[0]->type == GGML_TYPE_F16 ||
17251+
op->src[0]->type == GGML_TYPE_BF16) &&
17252+
op->type == op->src[0]->type &&
17253+
ggml_is_contiguous(op->src[0]) &&
17254+
ggml_is_contiguous(op);
1719117255
case GGML_OP_CONV_2D:
1719217256
case GGML_OP_CONV_TRANSPOSE_2D:
1719317257
{
@@ -18019,6 +18083,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1801918083
const int32_t p0 = tensor->op_params[1];
1802018084
const int32_t d0 = tensor->op_params[2];
1802118085
tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
18086+
} else if (tensor->op == GGML_OP_COL2IM_1D) {
18087+
const int32_t stride = tensor->op_params[0];
18088+
const int32_t oc = tensor->op_params[1];
18089+
const int32_t p0 = tensor->op_params[2];
18090+
tensor_clone = ggml_col2im_1d(ggml_ctx, src_clone[0], stride, oc, p0);
1802218091
} else if (tensor->op == GGML_OP_POOL_2D) {
1802318092
enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
1802418093
const int32_t k0 = tensor->op_params[1];
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#version 450
2+
3+
#include "types.glsl"
4+
5+
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // columns: [K_OC, T_in]
6+
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; // output: [T_out, OC]
7+
8+
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
9+
10+
layout (push_constant) uniform parameter {
11+
uint32_t T_out;
12+
uint32_t OC;
13+
uint32_t K_OC;
14+
uint32_t T_in;
15+
uint32_t K;
16+
int32_t stride;
17+
int32_t p0;
18+
} p;
19+
20+
// Load A_TYPE to float
21+
float load_col(uint32_t idx) {
22+
#if defined(DATA_A_BF16)
23+
return bf16_to_fp32(uint32_t(data_a[idx]));
24+
#else
25+
return float(data_a[idx]);
26+
#endif
27+
}
28+
29+
// Store float as D_TYPE
30+
void store_dst(uint32_t idx, float v) {
31+
#if defined(DATA_A_BF16)
32+
data_d[idx] = D_TYPE(fp32_to_bf16(v));
33+
#else
34+
data_d[idx] = D_TYPE(v);
35+
#endif
36+
}
37+
38+
void main() {
39+
const uint32_t t_out = gl_GlobalInvocationID.x;
40+
const uint32_t oc = gl_GlobalInvocationID.y;
41+
if (t_out >= p.T_out || oc >= p.OC) return;
42+
43+
const int32_t t_abs = int32_t(t_out) + p.p0; // absolute position in uncropped signal
44+
45+
// Gather: only the ceil(K/stride) columns that scatter into t_abs, no modulo
46+
int32_t t_in_min = (t_abs - int32_t(p.K) + p.stride) / p.stride;
47+
if (t_in_min < 0) t_in_min = 0;
48+
int32_t t_in_max = t_abs / p.stride;
49+
if (t_in_max >= int32_t(p.T_in)) t_in_max = int32_t(p.T_in) - 1;
50+
51+
float val = 0.0;
52+
for (int32_t t_in = t_in_min; t_in <= t_in_max; t_in++) {
53+
int32_t k = t_abs - t_in * p.stride;
54+
// col layout: [K_OC, T_in], column index = oc * K + k
55+
uint32_t col_idx = (oc * p.K + uint32_t(k)) + uint32_t(t_in) * p.K_OC;
56+
val += load_col(col_idx);
57+
}
58+
59+
// dst layout: [T_out, OC], element (t_out, oc) = t_out + oc * T_out
60+
store_dst(t_out + oc * p.T_out, val);
61+
}

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,9 @@ void process_shaders() {
10031003
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
10041004

10051005
string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
1006+
string_to_spv("col2im_1d_f32", "col2im_1d.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}});
1007+
string_to_spv("col2im_1d_f16", "col2im_1d.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
1008+
string_to_spv("col2im_1d_bf16", "col2im_1d.comp", {{"DATA_A_BF16", "1"}, {"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});
10061009

10071010
string_to_spv("snake_f32", "snake.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}});
10081011
string_to_spv("snake_f16", "snake.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});

0 commit comments

Comments
 (0)