Skip to content

Commit 7e876f4

Browse files
author
“AlexiAlp”
committed
使用自定义卷积替换ggml_conv_1d,防止metal只支持im2col fp16
1 parent 5843802 commit 7e876f4

3 files changed

Lines changed: 153 additions & 7 deletions

File tree

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1746,7 +1746,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
17461746
case GGML_OP_COS:
17471747
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
17481748
case GGML_OP_LOG:
1749-
return false; // TODO: implement
1749+
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1750+
// return false; // TODO: implement
17501751
case GGML_OP_SUM_ROWS:
17511752
case GGML_OP_MEAN:
17521753
case GGML_OP_SOFT_MAX:

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,6 +1192,13 @@ kernel void kernel_cos(
11921192
dst[tpig] = cos(src0[tpig]);
11931193
}
11941194

1195+
// Element-wise natural logarithm: dst[i] = log(src0[i]).
//
// Each thread processes the single element at its position in the dispatch
// grid, using the same flat indexing scheme as kernel_cos.
//
//   src0 - input buffer of floats
//   dst  - output buffer of floats, written at the same index that is read
//   tpig - this thread's position in the grid, used as the element index
//
// NOTE(review): the host-side supports_op check for GGML_OP_LOG requires a
// contiguous F32 source, which is what this flat indexing assumes — confirm
// the kernel is also registered and dispatched on the host side.
kernel void kernel_log(
        device const float * src0,
        device       float * dst,
        uint tpig [[thread_position_in_grid]]) {
    // The index variable must match the declared parameter; the previous body
    // used an undeclared `i` for the store and a misspelled name for the load.
    dst[tpig] = log(src0[tpig]);
}
1201+
11951202
kernel void kernel_neg(
11961203
device const float * src0,
11971204
device float * dst,

src/llama-graph.cpp

Lines changed: 144 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,67 @@ ggml_tensor * llm_graph_context::flip_weight(ggml_cgraph * gf, ggml_tensor * con
782782
}
783783
}
784784

785+
// 1-D convolution with stride 1, padding 0 and dilation 1, built from
// transpose / view / concat / mul_mat instead of ggml_conv_1d.
// (Per the commit that introduced it, this avoids ggml_conv_1d's im2col path,
// which is reportedly F16-only on the Metal backend.)
//
//   ctx  - ggml context in which the graph nodes are created
//   w_in - kernel, ne = (K, Cin, Cout, 1)
//   x_in - input,  ne = (T, Cin, 1, 1)
//
// Returns a tensor with ne = (Tout, Cout, 1, 1), where Tout = T - K + 1.
// Aborts via GGML_ASSERT if the shapes are inconsistent or Tout <= 0.
static ggml_tensor * conv1d_s1_p0_d1_mul_mat(ggml_context * ctx, ggml_tensor * w_in, ggml_tensor * x_in) {
    // ---- 1) validate shapes through ne[] directly
    // ggml_n_dims() ignores trailing dimensions of size 1, so gating on it
    // rejects valid inputs with Cin == 1 or Cout == 1 and can never see a
    // "4D" tensor whose ne[3] is 1. Checking the extents themselves is exact.
    GGML_ASSERT(x_in->ne[2] == 1 && x_in->ne[3] == 1);
    GGML_ASSERT(w_in->ne[3] == 1);

    const int64_t T    = x_in->ne[0];
    const int64_t Cin  = x_in->ne[1];

    const int64_t K    = w_in->ne[0];
    const int64_t CinW = w_in->ne[1];
    const int64_t Cout = w_in->ne[2];

    GGML_ASSERT(CinW == Cin);

    const int64_t Tout = T - K + 1;
    GGML_ASSERT(Tout > 0);

    // ---- 2) xt: (Cin, T), channel-fastest so that a time window is a plain view
    ggml_tensor * xt = ggml_cont(ctx, ggml_transpose(ctx, x_in));

    // ---- 3) build x_cols: (Cin*K, Tout)
    // Each xk is a (Cin, Tout) window of xt shifted by k steps along time;
    // the K windows are stacked along dim 0, so row index = k*Cin + c.
    const size_t st = xt->nb[1]; // stride of the time dimension in bytes
    ggml_tensor * x_cols = nullptr;

    for (int64_t k = 0; k < K; ++k) {
        ggml_tensor * xk = ggml_view_2d(ctx, xt, Cin, Tout, st, (size_t) k * st);
        x_cols = (x_cols == nullptr) ? xk : ggml_concat(ctx, x_cols, xk, 0);
    }
    // The concat result (or, for K == 1, the full-width view) is already
    // contiguous; only copy if that ever stops being true.
    if (!ggml_is_contiguous(x_cols)) {
        x_cols = ggml_cont(ctx, x_cols);
    }

    // ---- 4) build w2d: (Cin*K, Cout)
    // (K, Cin, Cout) -> (Cin, K, Cout), then flatten the first two dims so the
    // row order (k*Cin + c) matches x_cols.
    ggml_tensor * w2d = ggml_reshape_2d(
        ctx,
        ggml_cont(ctx, ggml_permute(ctx, w_in, 1, 0, 2, 3)), // (Cin, K, Cout)
        Cin * K, Cout
    );

    // ---- 5) GEMM: result is (Cout, Tout)
    ggml_tensor * y_ct = ggml_mul_mat(ctx, w2d, x_cols);

    // ---- 6) back to time-first layout and expose as 4D {Tout, Cout, 1, 1}
    ggml_tensor * y = ggml_cont(ctx, ggml_transpose(ctx, y_ct)); // (Tout, Cout)
    return ggml_reshape_4d(ctx, y, Tout, Cout, 1, 1);
}
845+
785846
ggml_tensor * llm_graph_context::build_pre_lookahead_layer(
786847
ggml_tensor * cur,
787848
ggml_tensor * conv1_mw,
@@ -794,7 +855,8 @@ ggml_tensor * llm_graph_context::build_pre_lookahead_layer(
794855
x = ggml_reshape_4d(ctx0, x, x->ne[0], x->ne[1], 1, 1);
795856
x = ggml_pad(ctx0, x, lookahead, 0, 0, 0);
796857
ggml_set_name(x, "x_pad");
797-
ggml_tensor * outputs = ggml_conv_1d(ctx0, conv1_mw, x, 1, 0, 1);
858+
// ggml_tensor * outputs = ggml_conv_1d(ctx0, conv1_mw, x, 1, 0, 1);
859+
ggml_tensor * outputs = conv1d_s1_p0_d1_mul_mat(ctx0, conv1_mw, x);
798860
conv1_mb = ggml_reshape_4d(ctx0, ggml_cont(ctx0, conv1_mb), 1, 512, 1, 1);
799861
outputs = ggml_add(ctx0, outputs, conv1_mb);
800862
ggml_set_name(outputs, "x_conv_1d");
@@ -805,7 +867,8 @@ ggml_tensor * llm_graph_context::build_pre_lookahead_layer(
805867
outputs = ggml_concat(ctx0, zeros, outputs, 0);
806868
outputs = ggml_cont(ctx0, outputs);
807869
ggml_set_name(outputs, "x_pad_2");
808-
outputs = ggml_conv_1d(ctx0, conv2_mw, outputs, 1, 0, 1);
870+
// outputs = ggml_conv_1d(ctx0, conv2_mw, outputs, 1, 0, 1);
871+
outputs = conv1d_s1_p0_d1_mul_mat(ctx0, conv2_mw, outputs);
809872
conv2_mb = ggml_reshape_4d(ctx0, ggml_cont(ctx0, conv2_mb), 1, 512, 1, 1);
810873
outputs = ggml_add(ctx0, outputs, conv2_mb);
811874
ggml_set_name(outputs, "x_conv_1d_2");
@@ -942,7 +1005,8 @@ ggml_tensor * llm_graph_context::build_upsample_1d(
9421005
ggml_tensor * pad = ggml_concat(ctx0, zeros, up, 1);
9431006
cb(pad, "upsample_pad", -1);
9441007
pad = ggml_cont(ctx0, ggml_permute(ctx0, pad, 1, 0, 2, 3));
945-
ggml_tensor * out = ggml_conv_1d(ctx0, mw, pad, 1, 0, 1);
1008+
// ggml_tensor * out = ggml_conv_1d(ctx0, mw, pad, 1, 0, 1);
1009+
ggml_tensor * out = conv1d_s1_p0_d1_mul_mat(ctx0, mw, pad);
9461010
mb = ggml_reshape_3d(ctx0, ggml_cont(ctx0, mb), 1, cur->ne[0], 1);
9471011
out = ggml_add(ctx0, out, mb);
9481012
cb(out, "upsample_conv_1d", -1);
@@ -1148,6 +1212,79 @@ ggml_tensor * llm_graph_context::build_basic_attn(
11481212
return attn_flat;
11491213
}
11501214

1215+
// Batched 1-D convolution (stride 1, padding 0, dilation 1) assembled from
// permute / view / concat / mul_mat rather than ggml_conv_1d.
//
//   ctx  - ggml context in which the graph nodes are created
//   w_in - kernel, expected ne = (K, Cin, Cout, 1)
//   x_in - input,  expected ne = (T, Cin, B, 1)
//
// Returns a tensor with ne = (Tout, Cout, B, 1), where Tout = T - K + 1.
// Aborts via GGML_ASSERT when ne[3] of either operand is not 1, when the
// channel counts disagree, or when Tout <= 0.
static ggml_tensor * conv1d_s1_p0_d1_mul_mat_batched(
        ggml_context * ctx,
        ggml_tensor  * w_in,
        ggml_tensor  * x_in) {
    // Input, viewed as (T, Cin, B, 1).
    ggml_tensor * input = x_in;
    if (ggml_n_dims(input) == 2) {
        input = ggml_reshape_4d(ctx, input, input->ne[0], input->ne[1], 1, 1);
    }
    GGML_ASSERT(input->ne[3] == 1);

    const int64_t n_time  = input->ne[0];
    const int64_t n_in    = input->ne[1];
    const int64_t n_batch = input->ne[2];

    // Kernel, viewed as (K, Cin, Cout, 1).
    ggml_tensor * kernel = w_in;
    if (ggml_n_dims(kernel) == 3) {
        kernel = ggml_reshape_4d(ctx, kernel, kernel->ne[0], kernel->ne[1], kernel->ne[2], 1);
    }
    GGML_ASSERT(kernel->ne[3] == 1);

    const int64_t n_taps      = kernel->ne[0];
    const int64_t n_in_kernel = kernel->ne[1];
    const int64_t n_out       = kernel->ne[2];

    GGML_ASSERT(n_in_kernel == n_in);
    const int64_t n_time_out = n_time - n_taps + 1;
    GGML_ASSERT(n_time_out > 0);

    // Channel-fastest copy of the input: (Cin, T, B, 1).
    ggml_tensor * input_ct = ggml_permute(ctx, input, 1, 0, 2, 3);
    input_ct = ggml_cont(ctx, input_ct);

    // Stack one time-shifted (Cin, Tout, B) window per kernel tap along dim 0,
    // giving (Cin*K, Tout, B) with row index = tap*Cin + channel.
    const size_t time_stride = input_ct->nb[1]; // bytes per time step
    ggml_tensor * columns = nullptr;

    int64_t tap = 0;
    while (tap < n_taps) {
        ggml_tensor * window = ggml_view_3d(
            ctx, input_ct,
            n_in, n_time_out, n_batch,
            input_ct->nb[1], input_ct->nb[2],
            (size_t) tap * time_stride);

        if (columns == nullptr) {
            columns = window;
        } else {
            columns = ggml_concat(ctx, columns, window, 0);
        }
        ++tap;
    }
    columns = ggml_cont(ctx, columns); // (Cin*K, Tout, B)

    // Kernel flattened to (Cin*K, Cout) in the same row order as `columns`:
    // (K, Cin, Cout, 1) -> (Cin, K, Cout, 1) -> (Cin*K, Cout).
    ggml_tensor * kernel_ck   = ggml_permute(ctx, kernel, 1, 0, 2, 3);
    ggml_tensor * kernel_cont = ggml_cont(ctx, kernel_ck);
    ggml_tensor * kernel_2d   = ggml_reshape_2d(ctx, kernel_cont, n_in * n_taps, n_out);

    // Matrix product: (Cout, Tout, B).
    ggml_tensor * out_ct = ggml_mul_mat(ctx, kernel_2d, columns);

    // Time-first result: (Tout, Cout, B, 1).
    ggml_tensor * out_tc = ggml_permute(ctx, out_ct, 1, 0, 2, 3);
    out_tc = ggml_cont(ctx, out_tc);
    return ggml_reshape_4d(ctx, out_tc, n_time_out, n_out, n_batch, 1);
}
1287+
11511288
ggml_tensor * llm_graph_context::causal_conv1d_forward(
11521289
ggml_tensor * x,
11531290
std::string mode,
@@ -1217,9 +1354,10 @@ ggml_tensor * llm_graph_context::causal_conv1d_forward(
12171354
x_pad->ne[0], x_pad->ne[1], 1, // [1536, 320, 1]
12181355
x_pad->nb[1], x_pad->nb[2],
12191356
x_pad->nb[2]);
1220-
ggml_tensor * y0 = ggml_conv_1d(ctx0, model_weight, x_batch0, 1, 0, 1);
1221-
ggml_tensor * y1 = ggml_conv_1d(ctx0, model_weight, x_batch1, 1, 0, 1);
1222-
y = ggml_concat(ctx0, y0, y1, 2);
1357+
// ggml_tensor * y0 = ggml_conv_1d(ctx0, model_weight, x_batch0, 1, 0, 1);
1358+
// ggml_tensor * y1 = ggml_conv_1d(ctx0, model_weight, x_batch1, 1, 0, 1);
1359+
// y = ggml_concat(ctx0, y0, y1, 2);
1360+
y = conv1d_s1_p0_d1_mul_mat_batched(ctx0, model_weight, x_pad);
12231361
} else {
12241362
y = ggml_conv_1d(ctx0, model_weight, x_pad, 1, 0, 1);
12251363
}

0 commit comments

Comments
 (0)