Skip to content

Commit 3de2490

Browse files
committed
Add conv_transpose_1d_gemm operator
Signed-off-by: Salvatore Mesoraca <[email protected]>
1 parent 19c5440 commit 3de2490

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

include/ggml.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,6 +1663,14 @@ extern "C" {
16631663
int p0, // padding
16641664
int d0); // dilation
16651665

1666+
// Transposed 1-D convolution implemented as a single GEMM followed by a
// col2im scatter (see the implementation in src/ggml.c); intended to give
// the same result as ggml_conv_transpose_1d.
// a: convolution kernel, b: input data; returns a new tensor in ctx.
GGML_API struct ggml_tensor * ggml_conv_transpose_1d_gemm(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
struct ggml_tensor * b, // data
int s0, // stride
int p0, // padding
int d0); // dilation
1673+
16661674
GGML_API struct ggml_tensor * ggml_conv_2d(
16671675
struct ggml_context * ctx,
16681676
struct ggml_tensor * a, // convolution kernel

src/ggml.c

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6774,6 +6774,48 @@ GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
67746774
return result;
67756775
}
67766776

6777+
GGML_API struct ggml_tensor * ggml_conv_transpose_1d_gemm(
6778+
struct ggml_context * ctx,
6779+
struct ggml_tensor * a, // KW OC IC
6780+
struct ggml_tensor * b, // IW IC N
6781+
int s0,
6782+
int p0,
6783+
int d0) {
6784+
GGML_ASSERT(a->ne[3] == 1);
6785+
GGML_ASSERT(b->ne[3] == 1);
6786+
GGML_ASSERT(a->ne[2] == b->ne[1]);
6787+
6788+
a = ggml_cont(ctx, ggml_permute(ctx, a, 2, 1, 0, 3)); // KW OC IC -> IC OC KW
6789+
b = ggml_permute(ctx, b, 1, 0, 2, 3); // IW IC N -> IC IW N
6790+
if (a->type == b->type)
6791+
b = ggml_cont(ctx, b);
6792+
else
6793+
b = ggml_cast(ctx, b, a->type);
6794+
const int64_t IC = a->ne[0];
6795+
assert(IC == b->ne[0]);
6796+
const int64_t KW = a->ne[2];
6797+
const int64_t OC = a->ne[1];
6798+
const int64_t IW = b->ne[1];
6799+
const int64_t N = b->ne[2];
6800+
// The following line isn't necessary, in theory,
6801+
// but makes CUDA use cublasSgemm instead of
6802+
// cublasGemmBatchedEx.
6803+
// The latter doesn't pass test-backend-ops
6804+
// because of F16 approximations
6805+
a = ggml_reshape_4d(ctx, a, IC, OC*KW, 1, 1);
6806+
b = ggml_reshape_4d(ctx, b, IC, IW*N, 1, 1);
6807+
struct ggml_tensor * mulres = ggml_mul_mat(ctx, b, a);
6808+
mulres = ggml_reshape_4d(ctx, mulres, IW, N, OC, KW);
6809+
mulres = ggml_permute(ctx, mulres, 0, 3, 2, 1); // -> IW KW OC N
6810+
return ggml_col2im(ctx,
6811+
mulres,
6812+
s0, 1 /* s1 */,
6813+
p0, 0 /* p1 */,
6814+
d0, 1 /* d1 */,
6815+
1 /* KH */,
6816+
1 /* IH */);
6817+
}
6818+
67776819
// ggml_conv_depthwise
67786820
struct ggml_tensor * ggml_conv_depthwise_2d(
67796821
struct ggml_context * ctx,

0 commit comments

Comments (0)