@@ -533,45 +533,6 @@ void implicit_gemm_conv_2D_general_gpu(
533533 compute_encoder.dispatch_threadgroups (grid_dims, group_dims);
534534}
535535
536- void winograd_conv_2D_fused_gpu (
537- const Stream& s,
538- metal::Device& d,
539- const array& in,
540- const array& wt,
541- array out,
542- const MLXConvParams<2 >& conv_params,
543- std::vector<array>& copies_w) {
544- int O_c = conv_params.O ;
545- int C_c = conv_params.C ;
546-
547- int N_tiles_n = conv_params.N ;
548- int N_tiles_h = (conv_params.oS [0 ] + 1 ) / 2 ;
549- int N_tiles_w = (conv_params.oS [1 ] + 1 ) / 2 ;
550- int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
551-
552- int bc = 32 ;
553- int wm = 4 ;
554- int wn = 1 ;
555- std::ostringstream kname;
556- kname << "winograd_conv_2d_fused_" << type_to_name(out) << "_flip"
557- << conv_params.flip;
558- auto & compute_encoder = d.get_command_encoder (s.index );
559- auto kernel = d.get_kernel (kname.str ());
560- compute_encoder.set_compute_pipeline_state (kernel);
561-
562- compute_encoder.set_input_array (in, 0 );
563- compute_encoder.set_input_array (wt, 1 );
564- compute_encoder.set_output_array (out, 2 );
565-
566- compute_encoder.set_bytes (conv_params, 3 );
567-
568- MTL::Size group_dims = MTL::Size (8 , 8 , 2 );
569- MTL::Size grid_dims =
570- MTL::Size (O_c / 8 , (N_tiles_h * N_tiles_w) / 8 , N_tiles_n);
571-
572- compute_encoder.dispatch_threadgroups (grid_dims, group_dims);
573- }
574-
575536void winograd_conv_2D_gpu (
576537 const Stream& s,
577538 metal::Device& d,
@@ -580,6 +541,67 @@ void winograd_conv_2D_gpu(
580541 array out,
581542 const MLXConvParams<2 >& conv_params,
582543 std::vector<array>& copies_w) {
544+ Shape padded_shape = {
545+ conv_params.N ,
546+ conv_params.iS [0 ] + 2 * conv_params.pad [0 ],
547+ conv_params.iS [1 ] + 2 * conv_params.pad [1 ],
548+ conv_params.C };
549+
550+ padded_shape[1 ] = 6 * ((padded_shape[1 ] - 2 + 5 ) / 6 ) + 2 ;
551+ padded_shape[2 ] = 6 * ((padded_shape[2 ] - 2 + 5 ) / 6 ) + 2 ;
552+
553+ array in_padded (std::move (padded_shape), in.dtype (), nullptr , {});
554+
555+ // Fill with zeros
556+ array zero_arr = array (0 , in.dtype ());
557+ fill_gpu (zero_arr, in_padded, s);
558+ copies_w.push_back (zero_arr);
559+
560+ // Pick input slice from padded
561+ size_t data_offset = conv_params.pad [0 ] * in_padded.strides ()[1 ] +
562+ conv_params.pad [1 ] * in_padded.strides ()[2 ];
563+ array in_padded_slice (in.shape (), in_padded.dtype (), nullptr , {});
564+ in_padded_slice.copy_shared_buffer (
565+ in_padded,
566+ in_padded.strides (),
567+ in_padded.flags (),
568+ in_padded_slice.size (),
569+ data_offset);
570+
571+ // Copy input values into the slice
572+ copy_gpu_inplace (in, in_padded_slice, CopyType::GeneralGeneral, s);
573+
574+ copies_w.push_back (in_padded_slice);
575+ copies_w.push_back (in_padded);
576+
577+ MLXConvParams<2 > conv_params_updated{
578+ /* const int N = */ static_cast <int >(in_padded.shape (0 )),
579+ /* const int C = */ static_cast <int >(in_padded.shape (3 )),
580+ /* const int O = */ static_cast <int >(wt.shape (0 )),
581+ /* const int iS[NDIM] = */
582+ {static_cast <int >(in_padded.shape (1 )),
583+ static_cast <int >(in_padded.shape (2 ))},
584+ /* const int wS[NDIM] = */
585+ {static_cast <int >(wt.shape (1 )), static_cast <int >(wt.shape (2 ))},
586+ /* const int oS[NDIM] = */
587+ {static_cast <int >(out.shape (1 )), static_cast <int >(out.shape (2 ))},
588+ /* const int str[NDIM] = */ {1 , 1 },
589+ /* const int pad[NDIM] = */ {0 , 0 },
590+ /* const int kdil[NDIM] = */ {1 , 1 },
591+ /* const int idil[NDIM] = */ {1 , 1 },
592+ /* const size_t in_strides[NDIM + 2] = */
593+ {in_padded.strides ()[0 ],
594+ in_padded.strides ()[1 ],
595+ in_padded.strides ()[2 ],
596+ in_padded.strides ()[3 ]},
597+ /* const size_t wt_strides[NDIM + 2] = */
598+ {wt.strides ()[0 ], wt.strides ()[1 ], wt.strides ()[2 ], wt.strides ()[3 ]},
599+ /* const size_t out_strides[NDIM + 2] = */
600+ {out.strides ()[0 ], out.strides ()[1 ], out.strides ()[2 ], out.strides ()[3 ]},
601+ /* const int groups = */ 1 ,
602+ /* const bool flip = */ false ,
603+ };
604+
583605 int O_c = conv_params.O ;
584606 int C_c = conv_params.C ;
585607
@@ -598,7 +620,7 @@ void winograd_conv_2D_gpu(
598620 int bo = 4 ;
599621 std::ostringstream kname;
600622 kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
601- << bc << "_flip" << conv_params.flip;
623+ << bc;
602624 auto & compute_encoder = d.get_command_encoder (s.index );
603625 auto kernel = d.get_kernel (kname.str ());
604626 compute_encoder.set_compute_pipeline_state (kernel);
@@ -631,10 +653,10 @@ void winograd_conv_2D_gpu(
631653 auto kernel = d.get_kernel (kname.str ());
632654 compute_encoder.set_compute_pipeline_state (kernel);
633655
634- compute_encoder.set_input_array (in , 0 );
656+ compute_encoder.set_input_array (in_padded , 0 );
635657 compute_encoder.set_output_array (inp_wg, 1 );
636658
637- compute_encoder.set_bytes (conv_params , 2 );
659+ compute_encoder.set_bytes (conv_params_updated , 2 );
638660
639661 MTL::Size group_dims = MTL::Size (32 , wn, wm);
640662 MTL::Size grid_dims = MTL::Size (N_tiles_w, N_tiles_h, N_tiles_n);
@@ -681,7 +703,7 @@ void winograd_conv_2D_gpu(
681703 compute_encoder.set_input_array (out_wg, 0 );
682704 compute_encoder.set_output_array (out, 1 );
683705
684- compute_encoder.set_bytes (conv_params , 2 );
706+ compute_encoder.set_bytes (conv_params_updated , 2 );
685707
686708 MTL::Size group_dims = MTL::Size (32 , wn, wm);
687709 MTL::Size grid_dims = MTL::Size (N_tiles_w, N_tiles_h, N_tiles_n);
@@ -745,18 +767,14 @@ void conv_2D_gpu(
745767 }
746768
747769 // Direct to winograd conv
748- bool img_large =
770+ bool inp_large =
749771 (conv_params.N * conv_params.iS [0 ] * conv_params.iS [1 ]) >= 1ul << 12 ;
750772 bool channels_large = (conv_params.C + conv_params.O ) >= 256 ;
751- if (conv_params.wS [0 ] == 3 && conv_params.wS [1 ] == 3 &&
752- conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && is_stride_one &&
753- is_kdil_one && is_idil_one) {
754- if (img_large && channels_large) {
755- return winograd_conv_2D_gpu (s, d, in, wt, out, conv_params, copies);
756- }
757- if (conv_params.N <= 1 ) {
758- return winograd_conv_2D_fused_gpu (s, d, in, wt, out, conv_params, copies);
759- }
773+ if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
774+ conv_params.wS [0 ] == 3 && conv_params.wS [1 ] == 3 &&
775+ conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
776+ channels_large) {
777+ return winograd_conv_2D_gpu (s, d, in, wt, out, conv_params, copies);
760778 }
761779
762780 // Direct to implicit gemm conv
@@ -858,40 +876,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
858876 wt = arr_copy;
859877 }
860878
861- // Check for 1x1 conv
862- auto is_one = [](int x) { return x == 1 ; };
863- auto is_zero = [](int x) { return x == 0 ; };
864- if (groups_ == 1 && (wt.shape (0 ) * wt.shape (-1 ) == wt.size ()) &&
865- std::all_of (wt.shape ().begin () + 1 , wt.shape ().end () - 1 , is_one) &&
866- std::all_of (kernel_strides_.begin (), kernel_strides_.end (), is_one) &&
867- std::all_of (input_dilation_.begin (), input_dilation_.end (), is_one) &&
868- std::all_of (kernel_dilation_.begin (), kernel_dilation_.end (), is_one) &&
869- std::all_of (padding_.begin (), padding_.end (), is_zero)) {
870- std::vector<array> empty_copies;
871- steel_matmul_regular (
872- s,
873- d,
874- /* a = */ in,
875- /* b = */ wt,
876- /* c = */ out,
877- /* M = */ in.size () / in.shape (-1 ),
878- /* N = */ wt.shape (0 ),
879- /* K = */ in.shape (-1 ),
880- /* batch_size_out = */ 1 ,
881- /* lda = */ in.shape (-1 ),
882- /* ldb = */ wt.shape (-1 ),
883- /* ldd = */ wt.shape (0 ),
884- /* transpose_a = */ false ,
885- /* transpose_b = */ true ,
886- /* batch_shape = */ {1 },
887- /* batch_strides = */ {1 },
888- /* A_batch_stride = */ 0 ,
889- /* B_batch_stride = */ 0 ,
890- /* matrix_stride_out = */ 0 ,
891- /* copies = */ empty_copies);
892- }
893879 // 3D conv
894- else if (out.ndim () == 5 ) {
880+ if (out.ndim () == 5 ) {
895881 conv_3D_gpu (
896882 s,
897883 d,
0 commit comments