@@ -902,6 +902,9 @@ struct vk_device_struct {
902902 vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
903903 vk_pipeline pipeline_timestep_embedding_f32;
904904 vk_pipeline pipeline_conv_transpose_1d_f32;
905+ vk_pipeline pipeline_col2im_1d_f32;
906+ vk_pipeline pipeline_col2im_1d_f16;
907+ vk_pipeline pipeline_col2im_1d_bf16;
905908 vk_pipeline pipeline_snake_f32;
906909 vk_pipeline pipeline_snake_f16;
907910 vk_pipeline pipeline_snake_bf16;
@@ -1552,6 +1555,16 @@ struct vk_op_timestep_embedding_push_constants {
15521555 uint32_t max_period;
15531556};
15541557
// Push-constant block handed to the col2im_1d pipelines (f32 / f16 / bf16).
// It is populated by ggml_vk_col2im_1d; the per-field notes below state where
// each value comes from in that function.
// NOTE(review): field order and widths presumably have to mirror the
// shader-side push-constant declaration — confirm before reordering or
// inserting members.
1558+ struct vk_op_col2im_1d_push_constants {
1559+ uint32_t T_out; // dst->ne[0]
1560+ uint32_t OC; // dst->op_params[1]
1561+ uint32_t K_OC; // src0->ne[0]
1562+ uint32_t T_in; // src0->ne[1]
1563+ uint32_t K; // K_OC / OC (integer division)
1564+ int32_t stride; // dst->op_params[0]
1565+ int32_t p0; // dst->op_params[2] (presumably the padding amount — confirm)
1566+ };
1567+
15551568struct vk_op_conv_transpose_1d_push_constants {
15561569 uint32_t Cout;
15571570 uint32_t Cin;
@@ -5203,6 +5216,9 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
52035216 ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
52045217
52055218 ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
5219+ ggml_vk_create_pipeline(device, device->pipeline_col2im_1d_f32, "col2im_1d_f32", col2im_1d_f32_len, col2im_1d_f32_data, "main", 2, sizeof(vk_op_col2im_1d_push_constants), {256, 1, 1}, {}, 1, true);
5220+ ggml_vk_create_pipeline(device, device->pipeline_col2im_1d_f16, "col2im_1d_f16", col2im_1d_f16_len, col2im_1d_f16_data, "main", 2, sizeof(vk_op_col2im_1d_push_constants), {256, 1, 1}, {}, 1, true);
5221+ ggml_vk_create_pipeline(device, device->pipeline_col2im_1d_bf16, "col2im_1d_bf16", col2im_1d_bf16_len, col2im_1d_bf16_data, "main", 2, sizeof(vk_op_col2im_1d_push_constants), {256, 1, 1}, {}, 1, true);
52065222
52075223 ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
52085224 ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
@@ -10702,6 +10718,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
1070210718 return ctx->device->pipeline_conv_transpose_1d_f32;
1070310719 }
1070410720 return nullptr;
10721+ case GGML_OP_COL2IM_1D:
10722+ switch (src0->type) {
10723+ case GGML_TYPE_F32: return ctx->device->pipeline_col2im_1d_f32;
10724+ case GGML_TYPE_F16: return ctx->device->pipeline_col2im_1d_f16;
10725+ case GGML_TYPE_BF16: return ctx->device->pipeline_col2im_1d_bf16;
10726+ default: return nullptr;
10727+ }
1070510728 case GGML_OP_POOL_2D:
1070610729 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
1070710730 return ctx->device->pipeline_pool2d_f32;
@@ -11147,6 +11170,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
1114711170 {
1114811171 elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1}
1114911172 } break;
11173+ case GGML_OP_COL2IM_1D:
11174+ {
11175+ elements = { uint32_t(dst->ne[0]), uint32_t(dst->ne[1]), 1 };
11176+ } break;
1115011177 case GGML_OP_POOL_2D:
1115111178 {
1115211179 const uint32_t N = dst->ne[3];
@@ -12936,6 +12963,32 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context&
1293612963 ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p));
1293712964}
1293812965
12966+ static void ggml_vk_col2im_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
12967+ // src0: [K_OC, T_in] columns from matmul
12968+ // dst: [T_out, OC]
12969+
12970+ const int32_t stride = dst->op_params[0];
12971+ const int32_t oc = dst->op_params[1];
12972+ const int32_t p0 = dst->op_params[2];
12973+
12974+ const uint32_t K_OC = static_cast<uint32_t>(src0->ne[0]);
12975+ const uint32_t T_in = static_cast<uint32_t>(src0->ne[1]);
12976+ const uint32_t T_out = static_cast<uint32_t>(dst->ne[0]);
12977+ const uint32_t OC = static_cast<uint32_t>(oc);
12978+ const uint32_t K = K_OC / OC;
12979+
12980+ vk_op_col2im_1d_push_constants p{};
12981+ p.T_out = T_out;
12982+ p.OC = OC;
12983+ p.K_OC = K_OC;
12984+ p.T_in = T_in;
12985+ p.K = K;
12986+ p.stride = stride;
12987+ p.p0 = p0;
12988+
12989+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COL2IM_1D, std::move(p));
12990+ }
12991+
1293912992// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b.
1294012993// Match the naive mul -> sin -> sqr -> mul -> add chain and run the
1294112994// dedicated kernel directly. The pattern is validated by
@@ -14423,6 +14476,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1442314476 case GGML_OP_TIMESTEP_EMBEDDING:
1442414477 ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node);
1442514478
14479+ break;
14480+ case GGML_OP_COL2IM_1D:
14481+ ggml_vk_col2im_1d(ctx, compute_ctx, src0, node);
14482+
1442614483 break;
1442714484 case GGML_OP_CONV_TRANSPOSE_1D:
1442814485 ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node);
@@ -17188,6 +17245,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1718817245 return op->src[0]->type == GGML_TYPE_F32;
1718917246 case GGML_OP_CONV_TRANSPOSE_1D:
1719017247 return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
17248+ case GGML_OP_COL2IM_1D:
17249+ return (op->src[0]->type == GGML_TYPE_F32 ||
17250+ op->src[0]->type == GGML_TYPE_F16 ||
17251+ op->src[0]->type == GGML_TYPE_BF16) &&
17252+ op->type == op->src[0]->type &&
17253+ ggml_is_contiguous(op->src[0]) &&
17254+ ggml_is_contiguous(op);
1719117255 case GGML_OP_CONV_2D:
1719217256 case GGML_OP_CONV_TRANSPOSE_2D:
1719317257 {
@@ -18019,6 +18083,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1801918083 const int32_t p0 = tensor->op_params[1];
1802018084 const int32_t d0 = tensor->op_params[2];
1802118085 tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0);
18086+ } else if (tensor->op == GGML_OP_COL2IM_1D) {
18087+ const int32_t stride = tensor->op_params[0];
18088+ const int32_t oc = tensor->op_params[1];
18089+ const int32_t p0 = tensor->op_params[2];
18090+ tensor_clone = ggml_col2im_1d(ggml_ctx, src_clone[0], stride, oc, p0);
1802218091 } else if (tensor->op == GGML_OP_POOL_2D) {
1802318092 enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
1802418093 const int32_t k0 = tensor->op_params[1];
0 commit comments