
Commit 885e62e

refactor: replace ggml_ext_attention with ggml_ext_attention_ext (#1185)

Parent: 0e52afc

File tree: 3 files changed (+7, -33 lines)

ggml_extend.hpp

Lines changed: 1 addition & 25 deletions
@@ -1208,35 +1208,11 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_cast_f32(ggml_context* ctx, ggml_tensor*
     } else {
         out = ggml_mul_mat(ctx, out, one);
     }
-    out = ggml_reshape(ctx, out, a);
+    out = ggml_reshape(ctx, out, a);
 #endif
     return out;
 }
 
-// q: [N * n_head, n_token, d_head]
-// k: [N * n_head, n_k, d_head]
-// v: [N * n_head, d_head, n_k]
-// return: [N * n_head, n_token, d_head]
-__STATIC_INLINE__ struct ggml_tensor* ggml_ext_attention(struct ggml_context* ctx,
-                                                         struct ggml_tensor* q,
-                                                         struct ggml_tensor* k,
-                                                         struct ggml_tensor* v,
-                                                         bool mask = false) {
-#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUDA) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL)
-    struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
-#else
-    float d_head = (float)q->ne[0];
-    struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k]
-    kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head));
-    if (mask) {
-        kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
-    }
-    kq = ggml_soft_max_inplace(ctx, kq);
-    struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head]
-#endif
-    return kqv;
-}
-
 // q: [N, L_q, C(n_head*d_head)] or [N*n_head, L_q, d_head]
 // k: [N, L_k, n_kv_head*d_head] or [N*n_kv_head, L_k, d_head]
 // v: [N, L_k, n_kv_head*d_head] or [N, L_k, n_kv_head, d_head]
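For reference, the deleted helper's non-flash branch is plain scaled dot-product attention built from ggml primitives: kq = softmax(k^T q / sqrt(d_head)), then kqv = v * kq with v pre-transposed. Below is a minimal standalone sketch of that same graph, using only the ops visible in the removed code; the context size, shapes, and fill values are arbitrary example choices, not taken from this commit.

// Standalone sketch of the removed non-flash fallback (mask = false path).
// Shapes follow the deleted comments: q [heads, n_token, d_head],
// k [heads, n_k, d_head], v [heads, d_head, n_k] (v pre-transposed).
#include "ggml.h"
#include <cmath>
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64 * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context* ctx = ggml_init(params);

    const int64_t d_head = 64, n_token = 16, n_k = 16, heads = 8; // example sizes
    struct ggml_tensor* q = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_head, n_token, heads);
    struct ggml_tensor* k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_head, n_k, heads);
    struct ggml_tensor* v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_k, d_head, heads);
    ggml_set_f32(q, 0.1f);
    ggml_set_f32(k, 0.1f);
    ggml_set_f32(v, 0.1f);

    // kq = softmax(k^T q / sqrt(d_head))  -> [heads, n_token, n_k]
    struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q);
    kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrtf((float)d_head));
    kq = ggml_soft_max_inplace(ctx, kq);
    // kqv = v kq                          -> [heads, n_token, d_head]
    struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq);

    struct ggml_cgraph* gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, kqv);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/4);

    // Uniform inputs: the softmax weights average v, so this prints 0.1.
    printf("kqv[0] = %f\n", ggml_get_f32_1d(kqv, 0));
    ggml_free(ctx);
    return 0;
}

The flash branch delegated to ggml_flash_attn instead; the replacement takes the backend as an explicit argument, which both updated call sites below now pass as ctx->backend.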

vae.hpp

Lines changed: 4 additions & 5 deletions
@@ -127,8 +127,6 @@ class AttnBlock : public UnaryBlock {
             q = q_proj->forward(ctx, h_); // [N, h * w, in_channels]
             k = k_proj->forward(ctx, h_); // [N, h * w, in_channels]
             v = v_proj->forward(ctx, h_); // [N, h * w, in_channels]
-
-            v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [N, in_channels, h * w]
         } else {
             q = q_proj->forward(ctx, h_); // [N, in_channels, h, w]
             q = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels]
@@ -138,11 +136,12 @@ class AttnBlock : public UnaryBlock {
             k = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, k, 1, 2, 0, 3)); // [N, h, w, in_channels]
             k = ggml_reshape_3d(ctx->ggml_ctx, k, c, h * w, n); // [N, h * w, in_channels]
 
-            v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
-            v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [N, in_channels, h * w]
+            v = v_proj->forward(ctx, h_); // [N, in_channels, h, w]
+            v = ggml_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, v, 1, 2, 0, 3)); // [N, h, w, in_channels]
+            v = ggml_reshape_3d(ctx->ggml_ctx, v, c, h * w, n); // [N, h * w, in_channels]
         }
 
-        h_ = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [N, h * w, in_channels]
+        h_ = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false);
 
         if (use_linear) {
             h_ = proj_out->forward(ctx, h_); // [N, h * w, in_channels]
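The layout change is the heart of this hunk: v now goes through the same permute + reshape as q and k, so all three reach the attention call token-major as [N, h * w, in_channels], instead of v arriving pre-transposed as [N, in_channels, h * w] for the old helper. Since ggml stores dimensions innermost-first (a tensor described as [N, c, h, w] has ne = {w, h, c, N}), the new v path can be checked with a metadata-only sketch; the concrete sizes are arbitrary examples, not taken from this commit.

// Shape check for the new v path in AttnBlock: permute(1, 2, 0, 3) then
// reshape to [N, h * w, c]. Metadata only, so no tensor data is allocated.
#include "ggml.h"
#include <cassert>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // metadata only
    };
    struct ggml_context* ctx = ggml_init(params);

    const int64_t n = 1, c = 512, h = 64, w = 64; // example sizes
    // conv output described as [N, c, h, w] -> ne = {w, h, c, n}
    struct ggml_tensor* v = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w, h, c, n);

    v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, h, w, c]
    assert(v->ne[0] == c && v->ne[1] == w && v->ne[2] == h && v->ne[3] == n);

    v = ggml_reshape_3d(ctx, v, c, h * w, n); // [N, h * w, c]
    assert(v->ne[0] == c && v->ne[1] == h * w && v->ne[2] == n);

    ggml_free(ctx);
    return 0;
}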

wan.hpp

Lines changed: 2 additions & 3 deletions
@@ -572,9 +572,8 @@ namespace WAN {
         auto v = qkv_vec[2];
         v = ggml_reshape_3d(ctx->ggml_ctx, v, h * w, c, n); // [t, c, h * w]
 
-        x = ggml_ext_attention(ctx->ggml_ctx, q, k, v, false); // [t, h * w, c]
-        // v = ggml_cont(ctx, ggml_ext_torch_permute(ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
-        // x = ggml_ext_attention_ext(ctx, q, k, v, q->ne[2], nullptr, false, false, true);
+        v = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, v, 1, 0, 2, 3)); // [t, h * w, c]
+        x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, 1, nullptr, false, true, false); // [t, h * w, c]
 
         x = ggml_ext_cont(ctx->ggml_ctx, ggml_permute(ctx->ggml_ctx, x, 1, 0, 2, 3)); // [t, c, h * w]
         x = ggml_reshape_4d(ctx->ggml_ctx, x, w, h, c, n); // [t, c, h, w]
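This hunk promotes the previously commented-out migration to live code, with two differences from the old comment: the call now threads ctx->backend through, and it passes n_head = 1 with flags (false, true, false) rather than q->ne[2] with (false, false, true); the meanings of the three booleans are not visible in this diff. The surrounding permutes are a matched pair: v's two innermost dims are swapped before attention ([t, c, h * w] to [t, h * w, c]) and swapped back on x afterwards. ggml_ext_torch_permute is a project-local helper not shown here, so the sketch below uses plain ggml_permute, for which the (1, 0, 2, 3) pattern is exactly that innermost swap; sizes are arbitrary examples.

// Metadata-only check that the pre-attention move of v and the
// post-attention move of x are the same innermost-dim swap.
#include "ggml.h"
#include <cassert>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16 * 1024 * 1024,
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true, // metadata only
    };
    struct ggml_context* ctx = ggml_init(params);

    const int64_t t = 4, c = 96, h = 32, w = 32; // example sizes
    // v as produced by the reshape above the hunk: [t, c, h * w] -> ne = {h * w, c, t}
    struct ggml_tensor* v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, h * w, c, t);

    // swap the two innermost dims: [t, c, h * w] -> [t, h * w, c]
    struct ggml_tensor* v_tok = ggml_cont(ctx, ggml_permute(ctx, v, 1, 0, 2, 3));
    assert(v_tok->ne[0] == c && v_tok->ne[1] == h * w && v_tok->ne[2] == t);

    // attention output keeps this layout; the same swap restores [t, c, h * w]
    struct ggml_tensor* x = ggml_cont(ctx, ggml_permute(ctx, v_tok, 1, 0, 2, 3));
    assert(x->ne[0] == h * w && x->ne[1] == c && x->ne[2] == t);

    ggml_free(ctx);
    return 0;
}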
