Skip to content

Commit daf6bc9

Browse files
authored
metal : fix im2col 1D case (audio models) (#24220)
1 parent d403f00 commit daf6bc9

2 files changed

Lines changed: 6 additions & 1 deletion

File tree

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1738,10 +1738,14 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_meta
17381738
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
17391739
GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
17401740

1741+
const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1;
1742+
const int64_t KH = is_2D ? ne01 : 1;
1743+
const int64_t KW = ne00;
1744+
17411745
char base[256];
17421746
char name[256];
17431747

1744-
if (ne00*ne01 <= 1024) {
1748+
if (KH*KW <= 1024) {
17451749
snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
17461750
} else {
17471751
snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type));

tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7771,6 +7771,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
77717771
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
77727772
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F32, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
77737773
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 128, 1, 1}, {3, 128, 1280, 1}, 1, 0, 1, 0, 1, 0, false));
7774+
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {3000, 384, 1, 1}, {3, 384, 384, 1}, 1, 0, 1, 0, 1, 0, false));
77747775
for (int s0 : {1, 3}) {
77757776
for (int p0 : {0, 3}) {
77767777
for (int d0 : {1, 3}) {

0 commit comments

Comments
 (0)