Skip to content

Commit f5f18b7

Browse files
authored
fix temporary bug (#752)
1 parent 420ff2f commit f5f18b7

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

mlx/backend/metal/conv.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ void explicit_gemm_conv_1D_gpu(
3131
array in_padded(padded_shape, in.dtype(), nullptr, {});
3232

3333
// Fill with zeros
34-
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
34+
auto zero = array(0, in.dtype());
35+
copy_gpu(zero, in_padded, CopyType::Scalar, s);
3536

3637
// Pick input slice from padded
3738
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1];
@@ -68,7 +69,7 @@ void explicit_gemm_conv_1D_gpu(
6869
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
6970

7071
// Perform gemm
71-
std::vector<array> copies = {in_padded, in_strided};
72+
std::vector<array> copies = {zero, in_padded, in_strided};
7273
return steel_matmul(
7374
s,
7475
d,
@@ -213,6 +214,7 @@ void explicit_gemm_conv_2D_gpu(
213214
array in_padded(padded_shape, in.dtype(), nullptr, {});
214215

215216
// Fill with zeros
217+
auto zero = array(0, in.dtype());
216218
copy_gpu(array(0, in.dtype()), in_padded, CopyType::Scalar, s);
217219

218220
// Pick input slice from padded
@@ -259,7 +261,7 @@ void explicit_gemm_conv_2D_gpu(
259261
copy_gpu(in_strided_view, in_strided, CopyType::General, s);
260262

261263
// Perform gemm
262-
std::vector<array> copies = {in_padded, in_strided};
264+
std::vector<array> copies = {zero, in_padded, in_strided};
263265
return steel_matmul(
264266
s,
265267
d,

0 commit comments

Comments
 (0)