@@ -31,7 +31,8 @@ void explicit_gemm_conv_1D_gpu(
3131 array in_padded (padded_shape, in.dtype (), nullptr , {});
3232
3333 // Fill with zeros
34- copy_gpu (array (0 , in.dtype ()), in_padded, CopyType::Scalar, s);
34+ auto zero = array (0 , in.dtype ());
35+ copy_gpu (zero, in_padded, CopyType::Scalar, s);
3536
3637 // Pick input slice from padded
3738 size_t data_offset = conv_params.pad [0 ] * in_padded.strides ()[1 ];
@@ -68,7 +69,7 @@ void explicit_gemm_conv_1D_gpu(
6869 copy_gpu (in_strided_view, in_strided, CopyType::General, s);
6970
7071 // Perform gemm
71- std::vector<array> copies = {in_padded, in_strided};
72+ std::vector<array> copies = {zero, in_padded, in_strided};
7273 return steel_matmul (
7374 s,
7475 d,
@@ -213,6 +214,7 @@ void explicit_gemm_conv_2D_gpu(
213214 array in_padded (padded_shape, in.dtype (), nullptr , {});
214215
215216 // Fill with zeros
217+ auto zero = array (0 , in.dtype ());
216218 copy_gpu (array (0 , in.dtype ()), in_padded, CopyType::Scalar, s);
217219
218220 // Pick input slice from padded
@@ -259,7 +261,7 @@ void explicit_gemm_conv_2D_gpu(
259261 copy_gpu (in_strided_view, in_strided, CopyType::General, s);
260262
261263 // Perform gemm
262- std::vector<array> copies = {in_padded, in_strided};
264+ std::vector<array> copies = {zero, in_padded, in_strided};
263265 return steel_matmul (
264266 s,
265267 d,
0 commit comments