Skip to content

Commit 71de73a

Browse files
authored
Fix convs by reverting #1803 (#1882)
1 parent 4c1dfa5 commit 71de73a

File tree

6 files changed

+96
-505
lines changed

6 files changed

+96
-505
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
2525
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
2626

2727
if(NOT MLX_VERSION)
28-
set(MLX_VERSION 0.23.0)
28+
set(MLX_VERSION 0.23.1)
2929
endif()
3030
add_compile_definitions("MLX_VERSION=${MLX_VERSION}")
3131

mlx/backend/metal/conv.cpp

Lines changed: 72 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -533,45 +533,6 @@ void implicit_gemm_conv_2D_general_gpu(
533533
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
534534
}
535535

536-
void winograd_conv_2D_fused_gpu(
537-
const Stream& s,
538-
metal::Device& d,
539-
const array& in,
540-
const array& wt,
541-
array out,
542-
const MLXConvParams<2>& conv_params,
543-
std::vector<array>& copies_w) {
544-
int O_c = conv_params.O;
545-
int C_c = conv_params.C;
546-
547-
int N_tiles_n = conv_params.N;
548-
int N_tiles_h = (conv_params.oS[0] + 1) / 2;
549-
int N_tiles_w = (conv_params.oS[1] + 1) / 2;
550-
int N_tiles = N_tiles_n * N_tiles_h * N_tiles_w;
551-
552-
int bc = 32;
553-
int wm = 4;
554-
int wn = 1;
555-
std::ostringstream kname;
556-
kname << "winograd_conv_2d_fused_" << type_to_name(out) << "_flip"
557-
<< conv_params.flip;
558-
auto& compute_encoder = d.get_command_encoder(s.index);
559-
auto kernel = d.get_kernel(kname.str());
560-
compute_encoder.set_compute_pipeline_state(kernel);
561-
562-
compute_encoder.set_input_array(in, 0);
563-
compute_encoder.set_input_array(wt, 1);
564-
compute_encoder.set_output_array(out, 2);
565-
566-
compute_encoder.set_bytes(conv_params, 3);
567-
568-
MTL::Size group_dims = MTL::Size(8, 8, 2);
569-
MTL::Size grid_dims =
570-
MTL::Size(O_c / 8, (N_tiles_h * N_tiles_w) / 8, N_tiles_n);
571-
572-
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
573-
}
574-
575536
void winograd_conv_2D_gpu(
576537
const Stream& s,
577538
metal::Device& d,
@@ -580,6 +541,67 @@ void winograd_conv_2D_gpu(
580541
array out,
581542
const MLXConvParams<2>& conv_params,
582543
std::vector<array>& copies_w) {
544+
Shape padded_shape = {
545+
conv_params.N,
546+
conv_params.iS[0] + 2 * conv_params.pad[0],
547+
conv_params.iS[1] + 2 * conv_params.pad[1],
548+
conv_params.C};
549+
550+
padded_shape[1] = 6 * ((padded_shape[1] - 2 + 5) / 6) + 2;
551+
padded_shape[2] = 6 * ((padded_shape[2] - 2 + 5) / 6) + 2;
552+
553+
array in_padded(std::move(padded_shape), in.dtype(), nullptr, {});
554+
555+
// Fill with zeros
556+
array zero_arr = array(0, in.dtype());
557+
fill_gpu(zero_arr, in_padded, s);
558+
copies_w.push_back(zero_arr);
559+
560+
// Pick input slice from padded
561+
size_t data_offset = conv_params.pad[0] * in_padded.strides()[1] +
562+
conv_params.pad[1] * in_padded.strides()[2];
563+
array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {});
564+
in_padded_slice.copy_shared_buffer(
565+
in_padded,
566+
in_padded.strides(),
567+
in_padded.flags(),
568+
in_padded_slice.size(),
569+
data_offset);
570+
571+
// Copy input values into the slice
572+
copy_gpu_inplace(in, in_padded_slice, CopyType::GeneralGeneral, s);
573+
574+
copies_w.push_back(in_padded_slice);
575+
copies_w.push_back(in_padded);
576+
577+
MLXConvParams<2> conv_params_updated{
578+
/* const int N = */ static_cast<int>(in_padded.shape(0)),
579+
/* const int C = */ static_cast<int>(in_padded.shape(3)),
580+
/* const int O = */ static_cast<int>(wt.shape(0)),
581+
/* const int iS[NDIM] = */
582+
{static_cast<int>(in_padded.shape(1)),
583+
static_cast<int>(in_padded.shape(2))},
584+
/* const int wS[NDIM] = */
585+
{static_cast<int>(wt.shape(1)), static_cast<int>(wt.shape(2))},
586+
/* const int oS[NDIM] = */
587+
{static_cast<int>(out.shape(1)), static_cast<int>(out.shape(2))},
588+
/* const int str[NDIM] = */ {1, 1},
589+
/* const int pad[NDIM] = */ {0, 0},
590+
/* const int kdil[NDIM] = */ {1, 1},
591+
/* const int idil[NDIM] = */ {1, 1},
592+
/* const size_t in_strides[NDIM + 2] = */
593+
{in_padded.strides()[0],
594+
in_padded.strides()[1],
595+
in_padded.strides()[2],
596+
in_padded.strides()[3]},
597+
/* const size_t wt_strides[NDIM + 2] = */
598+
{wt.strides()[0], wt.strides()[1], wt.strides()[2], wt.strides()[3]},
599+
/* const size_t out_strides[NDIM + 2] = */
600+
{out.strides()[0], out.strides()[1], out.strides()[2], out.strides()[3]},
601+
/* const int groups = */ 1,
602+
/* const bool flip = */ false,
603+
};
604+
583605
int O_c = conv_params.O;
584606
int C_c = conv_params.C;
585607

@@ -598,7 +620,7 @@ void winograd_conv_2D_gpu(
598620
int bo = 4;
599621
std::ostringstream kname;
600622
kname << "winograd_conv_2d_weight_transform_" << type_to_name(out) << "_bc"
601-
<< bc << "_flip" << conv_params.flip;
623+
<< bc;
602624
auto& compute_encoder = d.get_command_encoder(s.index);
603625
auto kernel = d.get_kernel(kname.str());
604626
compute_encoder.set_compute_pipeline_state(kernel);
@@ -631,10 +653,10 @@ void winograd_conv_2D_gpu(
631653
auto kernel = d.get_kernel(kname.str());
632654
compute_encoder.set_compute_pipeline_state(kernel);
633655

634-
compute_encoder.set_input_array(in, 0);
656+
compute_encoder.set_input_array(in_padded, 0);
635657
compute_encoder.set_output_array(inp_wg, 1);
636658

637-
compute_encoder.set_bytes(conv_params, 2);
659+
compute_encoder.set_bytes(conv_params_updated, 2);
638660

639661
MTL::Size group_dims = MTL::Size(32, wn, wm);
640662
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
@@ -681,7 +703,7 @@ void winograd_conv_2D_gpu(
681703
compute_encoder.set_input_array(out_wg, 0);
682704
compute_encoder.set_output_array(out, 1);
683705

684-
compute_encoder.set_bytes(conv_params, 2);
706+
compute_encoder.set_bytes(conv_params_updated, 2);
685707

686708
MTL::Size group_dims = MTL::Size(32, wn, wm);
687709
MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);
@@ -745,18 +767,14 @@ void conv_2D_gpu(
745767
}
746768

747769
// Direct to winograd conv
748-
bool img_large =
770+
bool inp_large =
749771
(conv_params.N * conv_params.iS[0] * conv_params.iS[1]) >= 1ul << 12;
750772
bool channels_large = (conv_params.C + conv_params.O) >= 256;
751-
if (conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
752-
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && is_stride_one &&
753-
is_kdil_one && is_idil_one) {
754-
if (img_large && channels_large) {
755-
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
756-
}
757-
if (conv_params.N <= 1) {
758-
return winograd_conv_2D_fused_gpu(s, d, in, wt, out, conv_params, copies);
759-
}
773+
if (!flip && is_stride_one && is_kdil_one && is_idil_one &&
774+
conv_params.wS[0] == 3 && conv_params.wS[1] == 3 &&
775+
conv_params.C % 32 == 0 && conv_params.O % 32 == 0 && inp_large &&
776+
channels_large) {
777+
return winograd_conv_2D_gpu(s, d, in, wt, out, conv_params, copies);
760778
}
761779

762780
// Direct to implicit gemm conv
@@ -858,40 +876,8 @@ void Convolution::eval_gpu(const std::vector<array>& inputs, array& out) {
858876
wt = arr_copy;
859877
}
860878

861-
// Check for 1x1 conv
862-
auto is_one = [](int x) { return x == 1; };
863-
auto is_zero = [](int x) { return x == 0; };
864-
if (groups_ == 1 && (wt.shape(0) * wt.shape(-1) == wt.size()) &&
865-
std::all_of(wt.shape().begin() + 1, wt.shape().end() - 1, is_one) &&
866-
std::all_of(kernel_strides_.begin(), kernel_strides_.end(), is_one) &&
867-
std::all_of(input_dilation_.begin(), input_dilation_.end(), is_one) &&
868-
std::all_of(kernel_dilation_.begin(), kernel_dilation_.end(), is_one) &&
869-
std::all_of(padding_.begin(), padding_.end(), is_zero)) {
870-
std::vector<array> empty_copies;
871-
steel_matmul_regular(
872-
s,
873-
d,
874-
/*a = */ in,
875-
/*b = */ wt,
876-
/*c = */ out,
877-
/*M = */ in.size() / in.shape(-1),
878-
/*N = */ wt.shape(0),
879-
/*K = */ in.shape(-1),
880-
/*batch_size_out = */ 1,
881-
/*lda = */ in.shape(-1),
882-
/*ldb = */ wt.shape(-1),
883-
/*ldd = */ wt.shape(0),
884-
/*transpose_a = */ false,
885-
/*transpose_b = */ true,
886-
/*batch_shape = */ {1},
887-
/*batch_strides = */ {1},
888-
/*A_batch_stride = */ 0,
889-
/*B_batch_stride = */ 0,
890-
/*matrix_stride_out = */ 0,
891-
/*copies = */ empty_copies);
892-
}
893879
// 3D conv
894-
else if (out.ndim() == 5) {
880+
if (out.ndim() == 5) {
895881
conv_3D_gpu(
896882
s,
897883
d,

0 commit comments

Comments (0)