@@ -6774,6 +6774,48 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
67746774 return result;
67756775}
67766776
6777+ GGML_API struct ggml_tensor * ggml_conv_transpose_1d_gemm(
6778+ struct ggml_context * ctx,
6779+ struct ggml_tensor * a, // KW OC IC
6780+ struct ggml_tensor * b, // IW IC N
6781+ int s0,
6782+ int p0,
6783+ int d0) {
6784+ GGML_ASSERT(a->ne[3] == 1);
6785+ GGML_ASSERT(b->ne[3] == 1);
6786+ GGML_ASSERT(a->ne[2] == b->ne[1]);
6787+
6788+ a = ggml_cont(ctx, ggml_permute(ctx, a, 2, 1, 0, 3)); // KW OC IC -> IC OC KW
6789+ b = ggml_permute(ctx, b, 1, 0, 2, 3); // IW IC N -> IC IW N
6790+ if (a->type == b->type)
6791+ b = ggml_cont(ctx, b);
6792+ else
6793+ b = ggml_cast(ctx, b, a->type);
6794+ const int64_t IC = a->ne[0];
6795+ assert(IC == b->ne[0]);
6796+ const int64_t KW = a->ne[2];
6797+ const int64_t OC = a->ne[1];
6798+ const int64_t IW = b->ne[1];
6799+ const int64_t N = b->ne[2];
6800+ // The following line isn't necessary, in theory,
6801+ // but makes CUDA use cublasSgemm instead of
6802+ // cublasGemmBatchedEx.
6803+ // The latter doesn't pass test-backend-ops
6804+ // because of F16 approximations
6805+ a = ggml_reshape_4d(ctx, a, IC, OC*KW, 1, 1);
6806+ b = ggml_reshape_4d(ctx, b, IC, IW*N, 1, 1);
6807+ struct ggml_tensor * mulres = ggml_mul_mat(ctx, b, a);
6808+ mulres = ggml_reshape_4d(ctx, mulres, IW, N, OC, KW);
6809+ mulres = ggml_permute(ctx, mulres, 0, 3, 2, 1); // -> IW KW OC N
6810+ return ggml_col2im(ctx,
6811+ mulres,
6812+ s0, 1 /* s1 */,
6813+ p0, 0 /* p1 */,
6814+ d0, 1 /* d1 */,
6815+ 1 /* KH */,
6816+ 1 /* IH */);
6817+ }
6818+
67776819// ggml_conv_depthwise
67786820struct ggml_tensor * ggml_conv_depthwise_2d(
67796821 struct ggml_context * ctx,
0 commit comments