Skip to content

Commit 2dc307f

Browse files
authored
Winograd Update for Small batches (#1803)
* Build in padding to Winograd kernels
* Add new fused Winograd kernel
* Enable weight flipping in Winograd kernels
1 parent 7aea5b1 commit 2dc307f

File tree

4 files changed

+505
-86
lines changed

4 files changed

+505
-86
lines changed

mlx/backend/metal/conv.cpp

Lines changed: 86 additions & 72 deletions
Original file line number · Diff line number · Diff line change
@@ -533,75 +533,53 @@ void implicit_gemm_conv_2D_general_gpu(
533533
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
534534
}
535535

536-
void winograd_conv_2D_gpu(
536+
void winograd_conv_2D_fused_gpu(
537537
const Stream& s,
538538
metal::Device& d,
539539
const array& in,
540540
const array& wt,
541541
array out,
542542
const MLXConvParams<2>& conv_params,
543543
std::vector<array>& copies_w) {
544-
Shape padded_shape = {
545-
conv_params.N,
546-
conv_params.iS[0] + 2 * conv_params.pad[0],
547-
conv_params.iS[1] + 2 * conv_params.pad[1],
548-
conv_params.C};
549-
550-
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
551-
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
552-
553-
array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
554-
555-
// Fill with zeros
556-
array zero_arr = array(0, in.dtype());
557-
fill_gpu(zero_arr, in_padded, s);
558-
copies_w.push_back(zero_arr);
559-
560-
// Pick input slice from padded
561-
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
562-
conv_params.pad[1] * in_padded.strides()[2];
563-
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
564-
in_padded_slice.copy_shared_buffer(
565-
in_padded,
566-
in_padded.strides(),
567-
in_padded.flags(),
568-
in_padded_slice.size(),
569-
data_offset);
570-
571-
// Copy input values into the slice
572-
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
573-
574-
copies_w.push_back(in_padded_slice);
575-
copies_w.push_back(in_padded);
576-
577-
MLXConvParams<2> conv_params_updated{
578-
/* const int N = */ static_cast<int>(in_padded.shape(0)),
579-
/* const int C = */ static_cast<int>(in_padded.shape(3)),
580-
/* const int O = */ static_cast<int>(wt.shape(0)),
581-
/* const int iS[NDIM] = */
582-
{static_cast<int>(in_padded.shape(1)),
583-
static_cast<int>(in_padded.shape(2))},
584-
/* const int wS[NDIM] = */
585-
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
586-
/* const int oS[NDIM] = */
587-
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
588-
/* const int str[NDIM] = */ {1, 1},
589-
/* const int pad[NDIM] = */ {0, 0},
590-
/* const int kdil[NDIM] = */ {1, 1},
591-
/* const int idil[NDIM] = */ {1, 1},
592-
/* const size_t in_strides[NDIM + 2] = */
593-
{in_padded.strides()[0],
594-
in_padded.strides()[1],
595-
in_padded.strides()[2],
596-
in_padded.strides()[3]},
597-
/* const size_t wt_strides[NDIM + 2] = */
598-
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
599-
/* const size_t out_strides[NDIM + 2] = */
600-
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
601-
/* const int groups = */ 1,
602-
/* const bool flip = */ false,
603-
};
544+
int O_c = conv_params.O;
545+
int C_c = conv_params.C;
546+
547+
int N_tiles_n = conv_params.N;
548+
int N_tiles_h = (conv_params.oS[0] + 1) / 2;
549+
int N_tiles_w = (conv_params.oS[1] + 1) / 2;
550+
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
551+
552+
int bc = 32;
553+
int wm = 4;
554+
int wn = 1;
555+
std::ostringstream kname;
556+
kname << "winograd_conv_2d_fused_" << type_to_name(out) << "_flip"
557+
<< conv_params.flip;
558+
auto& compute_encoder = d.get_command_encoder(s.index);
559+
auto kernel = d.get_kernel(kname.str());
560+
compute_encoder.set_compute_pipeline_state(kernel);
561+
562+
compute_encoder.set_input_array(in, 0);
563+
compute_encoder.set_input_array(wt, 1);
564+
compute_encoder.set_output_array(out, 2);
565+
566+
compute_encoder.set_bytes(conv_params, 3);
567+
568+
MTL::Size group_dims = MTL::Size(8, 8, 2);
569+
MTL::Size grid_dims =
570+
MTL::Size(O_c / 8, (N_tiles_h * N_tiles_w) / 8, N_tiles_n);
571+
572+
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
573+
}
604574

575+
void winograd_conv_2D_gpu(
576+
const Stream& s,
577+
metal::Device& d,
578+
const array& in,
579+
const array& wt,
580+
array out,
581+
const MLXConvParams<2>& conv_params,
582+
std::vector<array>& copies_w) {
605583
int O_c = conv_params.O;
606584
int C_c = conv_params.C;
607585

@@ -620,7 +598,7 @@ void winograd_conv_2D_gpu(
620598
int bo = 4;
621599
std::ostringstream kname;
622600
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
623-
<< bc;
601+
<< bc << "_flip" << conv_params.flip;
624602
auto& compute_encoder = d.get_command_encoder(s.index);
625603
auto kernel = d.get_kernel(kname.str());
626604
compute_encoder.set_compute_pipeline_state(kernel);
@@ -653,10 +631,10 @@ void winograd_conv_2D_gpu(
653631
auto kernel = d.get_kernel(kname.str());
654632
compute_encoder.set_compute_pipeline_state(kernel);
655633

656-
compute_encoder.set_input_array(in_padded, 0);
634+
compute_encoder.set_input_array(in, 0);
657635
compute_encoder.set_output_array(inp_wg, 1);
658636

659-
compute_encoder.set_bytes(conv_params_updated, 2);
637+
compute_encoder.set_bytes(conv_params, 2);
660638

661639
MTL::Size group_dims = MTL::Size(32, wn, wm);
662640
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
@@ -703,7 +681,7 @@ void winograd_conv_2D_gpu(
703681
compute_encoder.set_input_array(out_wg, 0);
704682
compute_encoder.set_output_array(out, 1);
705683

706-
compute_encoder.set_bytes(conv_params_updated, 2);
684+
compute_encoder.set_bytes(conv_params, 2);
707685

708686
MTL::Size group_dims = MTL::Size(32, wn, wm);
709687
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
@@ -767,14 +745,18 @@ void conv_2D_gpu(
767745
}
768746

769747
// Direct to winograd conv
770-
bool inp_large =
748+
bool img_large =
771749
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
772750
bool channels_large = (conv_params.C + conv_params.O) >= 256;
773-
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
774-
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
775-
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
776-
channels_large) {
777-
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
751+
if (conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
752+
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && is_stride_one &&
753+
is_kdil_one && is_idil_one) {
754+
if (img_large && channels_large) {
755+
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
756+
}
757+
if (conv_params.N <= 1) {
758+
return winograd_conv_2D_fused_gpu(s, d, in, wt, out, conv_params, copies);
759+
}
778760
}
779761

780762
// Direct to implicit gemm conv
@@ -876,8 +858,40 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
876858
wt = arr_copy;
877859
}
878860

861+
// Check for 1x1 conv
862+
auto is_one = [](int x) { return x == 1; };
863+
auto is_zero = [](int x) { return x == 0; };
864+
if (groups_ == 1 && (wt.shape(0) * wt.shape(-1) == wt.size()) &&
865+
std::all_of(wt.shape().begin() + 1, wt.shape().end() - 1, is_one) &&
866+
std::all_of(kernel_strides_.begin(), kernel_strides_.end(), is_one) &&
867+
std::all_of(input_dilation_.begin(), input_dilation_.end(), is_one) &&
868+
std::all_of(kernel_dilation_.begin(), kernel_dilation_.end(), is_one) &&
869+
std::all_of(padding_.begin(), padding_.end(), is_zero)) {
870+
std::vector<array> empty_copies;
871+
steel_matmul_regular(
872+
s,
873+
d,
874+
/*a = */ in,
875+
/*b = */ wt,
876+
/*c = */ out,
877+
/*M = */ in.size() / in.shape(-1),
878+
/*N = */ wt.shape(0),
879+
/*K = */ in.shape(-1),
880+
/*batch_size_out = */ 1,
881+
/*lda = */ in.shape(-1),
882+
/*ldb = */ wt.shape(-1),
883+
/*ldd = */ wt.shape(0),
884+
/*transpose_a = */ false,
885+
/*transpose_b = */ true,
886+
/*batch_shape = */ {1},
887+
/*batch_strides = */ {1},
888+
/*A_batch_stride = */ 0,
889+
/*B_batch_stride = */ 0,
890+
/*matrix_stride_out = */ 0,
891+
/*copies = */ empty_copies);
892+
}
879893
// 3D conv
880-
if (out.ndim() == 5) {
894+
else if (out.ndim() == 5) {
881895
conv_3D_gpu(
882896
s,
883897
d,

0 commit comments

Comments
 (0)