@@ -127,8 +127,6 @@ class AttnBlock : public UnaryBlock {
127127 q = q_proj->forward (ctx, h_); // [N, h * w, in_channels]
128128 k = k_proj->forward (ctx, h_); // [N, h * w, in_channels]
129129 v = v_proj->forward (ctx, h_); // [N, h * w, in_channels]
130-
131- v = ggml_cont (ctx->ggml_ctx , ggml_permute (ctx->ggml_ctx , v, 1 , 0 , 2 , 3 )); // [N, in_channels, h * w]
132130 } else {
133131 q = q_proj->forward (ctx, h_); // [N, in_channels, h, w]
134132 q = ggml_cont (ctx->ggml_ctx , ggml_permute (ctx->ggml_ctx , q, 1 , 2 , 0 , 3 )); // [N, h, w, in_channels]
@@ -138,11 +136,12 @@ class AttnBlock : public UnaryBlock {
138136 k = ggml_cont (ctx->ggml_ctx , ggml_permute (ctx->ggml_ctx , k, 1 , 2 , 0 , 3 )); // [N, h, w, in_channels]
139137 k = ggml_reshape_3d (ctx->ggml_ctx , k, c, h * w, n); // [N, h * w, in_channels]
140138
141- v = v_proj->forward (ctx, h_); // [N, in_channels, h, w]
142- v = ggml_reshape_3d (ctx->ggml_ctx , v, h * w, c, n); // [N, in_channels, h * w]
139+ v = v_proj->forward (ctx, h_); // [N, in_channels, h, w]
140+ v = ggml_cont (ctx->ggml_ctx , ggml_permute (ctx->ggml_ctx , v, 1 , 2 , 0 , 3 )); // [N, h, w, in_channels]
141+ v = ggml_reshape_3d (ctx->ggml_ctx , v, c, h * w, n); // [N, h * w, in_channels]
143142 }
144143
145- h_ = ggml_ext_attention (ctx->ggml_ctx , q, k, v, false ); // [N, h * w, in_channels]
144+ h_ = ggml_ext_attention_ext (ctx->ggml_ctx , ctx-> backend , q, k, v, 1 , nullptr , false , true , false );
146145
147146 if (use_linear) {
148147 h_ = proj_out->forward (ctx, h_); // [N, h * w, in_channels]
0 commit comments