@@ -533,75 +533,53 @@ void implicit_gemm_conv_2D_general_gpu(
533533 compute_encoder.dispatch_threadgroups (grid_dims, group_dims);
534534}
535535
536- void winograd_conv_2D_gpu (
536+ void winograd_conv_2D_fused_gpu (
537537 const Stream& s,
538538 metal::Device& d,
539539 const array& in,
540540 const array& wt,
541541 array out,
542542 const MLXConvParams<2 >& conv_params,
543543 std::vector<array>& copies_w) {
544- Shape padded_shape = {
545- conv_params.N ,
546- conv_params.iS [0 ] + 2 * conv_params.pad [0 ],
547- conv_params.iS [1 ] + 2 * conv_params.pad [1 ],
548- conv_params.C };
549-
550- padded_shape[1 ] = 6 * ((padded_shape[1 ] - 2 + 5 ) / 6 ) + 2 ;
551- padded_shape[2 ] = 6 * ((padded_shape[2 ] - 2 + 5 ) / 6 ) + 2 ;
552-
553- array in_padded (std::move (padded_shape), in.dtype (), nullptr , {});
554-
555- // Fill with zeros
556- array zero_arr = array (0 , in.dtype ());
557- fill_gpu (zero_arr, in_padded, s);
558- copies_w.push_back (zero_arr);
559-
560- // Pick input slice from padded
561- size_t data_offset = conv_params.pad [0 ] * in_padded.strides ()[1 ] +
562- conv_params.pad [1 ] * in_padded.strides ()[2 ];
563- array in_padded_slice (in.shape (), in_padded.dtype (), nullptr , {});
564- in_padded_slice.copy_shared_buffer (
565- in_padded,
566- in_padded.strides (),
567- in_padded.flags (),
568- in_padded_slice.size (),
569- data_offset);
570-
571- // Copy input values into the slice
572- copy_gpu_inplace (in, in_padded_slice, CopyType::GeneralGeneral, s);
573-
574- copies_w.push_back (in_padded_slice);
575- copies_w.push_back (in_padded);
576-
577- MLXConvParams<2 > conv_params_updated{
578- /* const int N = */ static_cast <int >(in_padded.shape (0 )),
579- /* const int C = */ static_cast <int >(in_padded.shape (3 )),
580- /* const int O = */ static_cast <int >(wt.shape (0 )),
581- /* const int iS[NDIM] = */
582- {static_cast <int >(in_padded.shape (1 )),
583- static_cast <int >(in_padded.shape (2 ))},
584- /* const int wS[NDIM] = */
585- {static_cast <int >(wt.shape (1 )), static_cast <int >(wt.shape (2 ))},
586- /* const int oS[NDIM] = */
587- {static_cast <int >(out.shape (1 )), static_cast <int >(out.shape (2 ))},
588- /* const int str[NDIM] = */ {1 , 1 },
589- /* const int pad[NDIM] = */ {0 , 0 },
590- /* const int kdil[NDIM] = */ {1 , 1 },
591- /* const int idil[NDIM] = */ {1 , 1 },
592- /* const size_t in_strides[NDIM + 2] = */
593- {in_padded.strides ()[0 ],
594- in_padded.strides ()[1 ],
595- in_padded.strides ()[2 ],
596- in_padded.strides ()[3 ]},
597- /* const size_t wt_strides[NDIM + 2] = */
598- {wt.strides ()[0 ], wt.strides ()[1 ], wt.strides ()[2 ], wt.strides ()[3 ]},
599- /* const size_t out_strides[NDIM + 2] = */
600- {out.strides ()[0 ], out.strides ()[1 ], out.strides ()[2 ], out.strides ()[3 ]},
601- /* const int groups = */ 1 ,
602- /* const bool flip = */ false ,
603- };
544+ int O_c = conv_params.O ;
545+ int C_c = conv_params.C ;
546+
547+ int N_tiles_n = conv_params.N ;
548+ int N_tiles_h = (conv_params.oS [0 ] + 1 ) / 2 ;
549+ int N_tiles_w = (conv_params.oS [1 ] + 1 ) / 2 ;
550+ int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
551+
552+ int bc = 32 ;
553+ int wm = 4 ;
554+ int wn = 1 ;
555+ std::ostringstream kname;
556+ kname << " winograd_conv_2d_fused_" << type_to_name (out) << " _flip"
557+ << conv_params.flip ;
558+ auto & compute_encoder = d.get_command_encoder (s.index );
559+ auto kernel = d.get_kernel (kname.str ());
560+ compute_encoder.set_compute_pipeline_state (kernel);
561+
562+ compute_encoder.set_input_array (in, 0 );
563+ compute_encoder.set_input_array (wt, 1 );
564+ compute_encoder.set_output_array (out, 2 );
565+
566+ compute_encoder.set_bytes (conv_params, 3 );
567+
568+ MTL::Size group_dims = MTL::Size (8 , 8 , 2 );
569+ MTL::Size grid_dims =
570+ MTL::Size (O_c / 8 , (N_tiles_h * N_tiles_w) / 8 , N_tiles_n);
571+
572+ compute_encoder.dispatch_threadgroups (grid_dims, group_dims);
573+ }
604574
575+ void winograd_conv_2D_gpu (
576+ const Stream& s,
577+ metal::Device& d,
578+ const array& in,
579+ const array& wt,
580+ array out,
581+ const MLXConvParams<2 >& conv_params,
582+ std::vector<array>& copies_w) {
605583 int O_c = conv_params.O ;
606584 int C_c = conv_params.C ;
607585
@@ -620,7 +598,7 @@ void winograd_conv_2D_gpu(
620598 int bo = 4 ;
621599 std::ostringstream kname;
622600 kname << " winograd_conv_2d_weight_transform_" << type_to_name (out) << " _bc"
623- << bc;
601+ << bc << " _flip" << conv_params.flip ;
624602 auto & compute_encoder = d.get_command_encoder (s.index );
625603 auto kernel = d.get_kernel (kname.str ());
626604 compute_encoder.set_compute_pipeline_state (kernel);
@@ -653,10 +631,10 @@ void winograd_conv_2D_gpu(
653631 auto kernel = d.get_kernel (kname.str ());
654632 compute_encoder.set_compute_pipeline_state (kernel);
655633
656- compute_encoder.set_input_array (in_padded , 0 );
634+ compute_encoder.set_input_array (in , 0 );
657635 compute_encoder.set_output_array (inp_wg, 1 );
658636
659- compute_encoder.set_bytes (conv_params_updated , 2 );
637+ compute_encoder.set_bytes (conv_params , 2 );
660638
661639 MTL::Size group_dims = MTL::Size (32 , wn, wm);
662640 MTL::Size grid_dims = MTL::Size (N_tiles_w, N_tiles_h, N_tiles_n);
@@ -703,7 +681,7 @@ void winograd_conv_2D_gpu(
703681 compute_encoder.set_input_array (out_wg, 0 );
704682 compute_encoder.set_output_array (out, 1 );
705683
706- compute_encoder.set_bytes (conv_params_updated , 2 );
684+ compute_encoder.set_bytes (conv_params , 2 );
707685
708686 MTL::Size group_dims = MTL::Size (32 , wn, wm);
709687 MTL::Size grid_dims = MTL::Size (N_tiles_w, N_tiles_h, N_tiles_n);
@@ -767,14 +745,18 @@ void conv_2D_gpu(
767745 }
768746
769747 // Direct to winograd conv
770- bool inp_large =
748+ bool img_large =
771749 (conv_params.N * conv_params.iS [0 ] * conv_params.iS [1 ]) >= 1ul << 12 ;
772750 bool channels_large = (conv_params.C + conv_params.O ) >= 256 ;
773- if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
774- conv_params.wS [0 ] == 3 && conv_params.wS [1 ] == 3 &&
775- conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
776- channels_large) {
777- return winograd_conv_2D_gpu (s, d, in, wt, out, conv_params, copies);
751+ if (conv_params.wS [0 ] == 3 && conv_params.wS [1 ] == 3 &&
752+ conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && is_stride_one &&
753+ is_kdil_one && is_idil_one) {
754+ if (img_large && channels_large) {
755+ return winograd_conv_2D_gpu (s, d, in, wt, out, conv_params, copies);
756+ }
757+ if (conv_params.N <= 1 ) {
758+ return winograd_conv_2D_fused_gpu (s, d, in, wt, out, conv_params, copies);
759+ }
778760 }
779761
780762 // Direct to implicit gemm conv
@@ -876,8 +858,40 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
876858 wt = arr_copy;
877859 }
878860
861+ // Check for 1x1 conv
862+ auto is_one = [](int x) { return x == 1 ; };
863+ auto is_zero = [](int x) { return x == 0 ; };
864+ if (groups_ == 1 && (wt.shape (0 ) * wt.shape (-1 ) == wt.size ()) &&
865+ std::all_of (wt.shape ().begin () + 1 , wt.shape ().end () - 1 , is_one) &&
866+ std::all_of (kernel_strides_.begin (), kernel_strides_.end (), is_one) &&
867+ std::all_of (input_dilation_.begin (), input_dilation_.end (), is_one) &&
868+ std::all_of (kernel_dilation_.begin (), kernel_dilation_.end (), is_one) &&
869+ std::all_of (padding_.begin (), padding_.end (), is_zero)) {
870+ std::vector<array> empty_copies;
871+ steel_matmul_regular (
872+ s,
873+ d,
874+ /* a = */ in,
875+ /* b = */ wt,
876+ /* c = */ out,
877+ /* M = */ in.size () / in.shape (-1 ),
878+ /* N = */ wt.shape (0 ),
879+ /* K = */ in.shape (-1 ),
880+ /* batch_size_out = */ 1 ,
881+ /* lda = */ in.shape (-1 ),
882+ /* ldb = */ wt.shape (-1 ),
883+ /* ldd = */ wt.shape (0 ),
884+ /* transpose_a = */ false ,
885+ /* transpose_b = */ true ,
886+ /* batch_shape = */ {1 },
887+ /* batch_strides = */ {1 },
888+ /* A_batch_stride = */ 0 ,
889+ /* B_batch_stride = */ 0 ,
890+ /* matrix_stride_out = */ 0 ,
891+ /* copies = */ empty_copies);
892+ }
879893 // 3D conv
880- if (out.ndim () == 5 ) {
894+ else if (out.ndim () == 5 ) {
881895 conv_3D_gpu (
882896 s,
883897 d,
0 commit comments