Skip to content

Commit 7e26fd8

Browse files
authored
Option to JIT steel gemm / conv (#1139)
1 parent eab2685 commit 7e26fd8

31 files changed

+2504
-1540
lines changed

mlx/backend/metal/CMakeLists.txt

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function(make_jit_source SRC_NAME)
1+
function(make_jit_source SRC_FILE)
22
# This function takes a metal header file,
33
# runs the C preprocessesor on it, and makes
44
# the processed contents available as a string in a C++ function
@@ -9,17 +9,18 @@ function(make_jit_source SRC_NAME)
99
#
1010
# Additional arguments to this function are treated as dependencies
1111
# in the Cmake build system.
12+
get_filename_component(SRC_NAME ${SRC_FILE} NAME)
1213
add_custom_command(
1314
OUTPUT jit/${SRC_NAME}.cpp
1415
COMMAND /bin/bash
1516
${CMAKE_CURRENT_SOURCE_DIR}/make_compiled_preamble.sh
1617
${CMAKE_CURRENT_BINARY_DIR}/jit
1718
${CMAKE_C_COMPILER}
1819
${PROJECT_SOURCE_DIR}
19-
${SRC_NAME}
20+
${SRC_FILE}
2021
"-D${MLX_METAL_VERSION}"
2122
DEPENDS make_compiled_preamble.sh
22-
kernels/${SRC_NAME}.h
23+
kernels/${SRC_FILE}.h
2324
${ARGN}
2425
)
2526
add_custom_target(${SRC_NAME} DEPENDS jit/${SRC_NAME}.cpp)
@@ -73,6 +74,39 @@ if (MLX_METAL_JIT)
7374
kernels/reduction/reduce_col.h
7475
kernels/reduction/reduce_row.h
7576
)
77+
make_jit_source(
78+
steel/gemm/gemm
79+
kernels/steel/utils.h
80+
kernels/steel/gemm/loader.h
81+
kernels/steel/gemm/mma.h
82+
kernels/steel/gemm/params.h
83+
kernels/steel/gemm/transforms.h
84+
)
85+
make_jit_source(steel/gemm/kernels/steel_gemm_fused)
86+
make_jit_source(
87+
steel/gemm/kernels/steel_gemm_masked
88+
kernels/steel/defines.h
89+
)
90+
make_jit_source(steel/gemm/kernels/steel_gemm_splitk)
91+
make_jit_source(
92+
steel/conv/conv
93+
kernels/steel/utils.h
94+
kernels/steel/defines.h
95+
kernels/steel/gemm/mma.h
96+
kernels/steel/gemm/transforms.h
97+
kernels/steel/conv/params.h
98+
kernels/steel/conv/loader.h
99+
kernels/steel/conv/loaders/loader_channel_l.h
100+
kernels/steel/conv/loaders/loader_channel_n.h
101+
)
102+
make_jit_source(
103+
steel/conv/kernels/steel_conv
104+
)
105+
make_jit_source(
106+
steel/conv/kernels/steel_conv_general
107+
kernels/steel/defines.h
108+
kernels/steel/conv/loaders/loader_general.h
109+
)
76110
else()
77111
target_sources(
78112
mlx

mlx/backend/metal/conv.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#include "mlx/backend/metal/copy.h"
99
#include "mlx/backend/metal/device.h"
10+
#include "mlx/backend/metal/kernels.h"
1011
#include "mlx/backend/metal/kernels/defines.h"
1112
#include "mlx/backend/metal/kernels/steel/conv/params.h"
1213
#include "mlx/backend/metal/matmul.h"
@@ -335,7 +336,17 @@ void implicit_gemm_conv_2D_gpu(
335336

336337
// Encode and dispatch kernel
337338
auto& compute_encoder = d.get_command_encoder(s.index);
338-
auto kernel = d.get_kernel(kname.str());
339+
auto kernel = get_steel_conv_kernel(
340+
d,
341+
kname.str(),
342+
out,
343+
bm,
344+
bn,
345+
bk,
346+
wm,
347+
wn,
348+
n_channel_specialization,
349+
small_filter);
339350
compute_encoder->setComputePipelineState(kernel);
340351

341352
// Deduce grid launch dimensions
@@ -488,7 +499,8 @@ void implicit_gemm_conv_2D_general_gpu(
488499

489500
// Encode and dispatch kernel
490501
auto& compute_encoder = d.get_command_encoder(s.index);
491-
auto kernel = d.get_kernel(kname.str());
502+
auto kernel =
503+
get_steel_conv_general_kernel(d, kname.str(), out, bm, bn, bk, wm, wn);
492504
compute_encoder->setComputePipelineState(kernel);
493505

494506
// Deduce grid launch dimensions

mlx/backend/metal/jit/includes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,12 @@ const char* softmax();
2323
const char* sort();
2424
const char* reduce();
2525

26+
const char* gemm();
27+
const char* steel_gemm_fused();
28+
const char* steel_gemm_masked();
29+
const char* steel_gemm_splitk();
30+
const char* conv();
31+
const char* steel_conv();
32+
const char* steel_conv_general();
33+
2634
} // namespace mlx::core::metal

mlx/backend/metal/jit/steel_conv.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
constexpr std::string_view steel_conv_kernels = R"(
4+
template [[host_name("{name}")]] [[kernel]] void
5+
implicit_gemm_conv_2d<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {n_channels}, {small_filter}>(
6+
const device {itype}* A [[buffer(0)]],
7+
const device {itype}* B [[buffer(1)]],
8+
device {itype}* C [[buffer(2)]],
9+
const constant MLXConvParams<2>* params [[buffer(3)]],
10+
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
11+
uint3 tid [[threadgroup_position_in_grid]],
12+
uint3 lid [[thread_position_in_threadgroup]],
13+
uint simd_gid [[simdgroup_index_in_threadgroup]],
14+
uint simd_lid [[thread_index_in_simdgroup]]);
15+
)";
16+
17+
constexpr std::string_view steel_conv_general_kernels = R"(
18+
template [[host_name("{name}")]] [[kernel]] void
19+
implicit_gemm_conv_2d_general<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}>(
20+
const device {itype}* A [[buffer(0)]],
21+
const device {itype}* B [[buffer(1)]],
22+
device {itype}* C [[buffer(2)]],
23+
const constant MLXConvParams<2>* params [[buffer(3)]],
24+
const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]],
25+
const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]],
26+
const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]],
27+
const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]],
28+
uint3 tid [[threadgroup_position_in_grid]],
29+
uint3 lid [[thread_position_in_threadgroup]],
30+
uint simd_gid [[simdgroup_index_in_threadgroup]],
31+
uint simd_lid [[thread_index_in_simdgroup]]);
32+
)";

mlx/backend/metal/jit/steel_gemm.h

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
constexpr std::string_view steel_gemm_fused_kernels = R"(
4+
template [[host_name("{name}")]]
5+
[[kernel]] void gemm<{itype}, {bm}, {bn}, {bk}, {wm}, {wn}, {trans_a}, {trans_b}, float>(
6+
const device {itype} *A [[buffer(0)]],
7+
const device {itype} *B [[buffer(1)]],
8+
const device {itype} *C [[buffer(2), function_constant(use_out_source)]],
9+
device {itype} *D [[buffer(3)]],
10+
const constant GEMMParams* params [[buffer(4)]],
11+
const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]],
12+
const constant int* batch_shape [[buffer(6)]],
13+
const constant size_t* batch_strides [[buffer(7)]],
14+
const constant uint32_t* lhs_indices [[buffer(10), function_constant(do_gather)]],
15+
const constant uint32_t* rhs_indices [[buffer(11), function_constant(do_gather)]],
16+
const constant uint32_t* C_indices [[buffer(12), function_constant(gather_bias)]],
17+
const constant int* operand_shape [[buffer(13), function_constant(do_gather)]],
18+
const constant size_t* operand_strides [[buffer(14), function_constant(do_gather)]],
19+
const constant packed_int3& operand_batch_ndim [[buffer(15), function_constant(do_gather)]],
20+
uint simd_lane_id [[thread_index_in_simdgroup]],
21+
uint simd_group_id [[simdgroup_index_in_threadgroup]],
22+
uint3 tid [[threadgroup_position_in_grid]],
23+
uint3 lid [[thread_position_in_threadgroup]]);
24+
)";
25+
26+
constexpr std::string_view steel_gemm_masked_kernels = R"(
27+
template [[host_name("{name}")]] [[kernel]] void
28+
block_masked_gemm<
29+
{itype},
30+
{outmasktype},
31+
{opmasktype},
32+
{bm},
33+
{bn},
34+
{bk},
35+
{wm},
36+
{wn},
37+
{trans_a},
38+
{trans_b},
39+
{mn_aligned},
40+
{k_aligned}>(
41+
const device {itype}* A [[buffer(0)]],
42+
const device {itype}* B [[buffer(1)]],
43+
device {itype}* D [[buffer(3)]],
44+
const constant GEMMParams* params [[buffer(4)]],
45+
const constant int* batch_shape [[buffer(6)]],
46+
const constant size_t* batch_strides [[buffer(7)]],
47+
const device {outmasktype}* out_mask [[buffer(10)]],
48+
const device {opmasktype}* lhs_mask [[buffer(11)]],
49+
const device {opmasktype}* rhs_mask [[buffer(12)]],
50+
const constant int* mask_strides [[buffer(13)]],
51+
uint simd_lane_id [[thread_index_in_simdgroup]],
52+
uint simd_group_id [[simdgroup_index_in_threadgroup]],
53+
uint3 tid [[threadgroup_position_in_grid]],
54+
uint3 lid [[thread_position_in_threadgroup]]);
55+
)";
56+
57+
constexpr std::string_view steel_gemm_splitk_kernels = R"(
58+
template [[host_name("{name}")]] [[kernel]] void
59+
gemm_splitk<
60+
{itype},
61+
{otype},
62+
{bm},
63+
{bn},
64+
{bk},
65+
{wm},
66+
{wn},
67+
{trans_a},
68+
{trans_b},
69+
{mn_aligned},
70+
{k_aligned}>(
71+
const device {itype}* A [[buffer(0)]],
72+
const device {itype}* B [[buffer(1)]],
73+
device {otype}* C [[buffer(2)]],
74+
const constant GEMMSpiltKParams* params [[buffer(3)]],
75+
uint simd_lane_id [[thread_index_in_simdgroup]],
76+
uint simd_group_id [[simdgroup_index_in_threadgroup]],
77+
uint3 tid [[threadgroup_position_in_grid]],
78+
uint3 lid [[thread_position_in_threadgroup]]);
79+
)";
80+
81+
constexpr std::string_view steel_gemm_splitk_accum_kernels = R"(
82+
template [[host_name("{name}")]] [[kernel]] void
83+
gemm_splitk_accum<{atype}, {otype}>(
84+
const device {atype}* C_split [[buffer(0)]],
85+
device {otype}* D [[buffer(1)]],
86+
const constant int& k_partitions [[buffer(2)]],
87+
const constant int& partition_stride [[buffer(3)]],
88+
const constant int& ldd [[buffer(4)]],
89+
uint2 gid [[thread_position_in_grid]]);
90+
)";
91+
92+
constexpr std::string_view steel_gemm_splitk_accum_axbpy_kernels = R"(
93+
template [[host_name("{name}")]] [[kernel]] void
94+
gemm_splitk_accum_axpby<{atype}, {otype}>(
95+
const device {atype}* C_split [[buffer(0)]],
96+
device {otype}* D [[buffer(1)]],
97+
const constant int& k_partitions [[buffer(2)]],
98+
const constant int& partition_stride [[buffer(3)]],
99+
const constant int& ldd [[buffer(4)]],
100+
const device {otype}* C [[buffer(5)]],
101+
const constant int& ldc [[buffer(6)]],
102+
const constant int& fdc [[buffer(7)]],
103+
const constant float& alpha [[buffer(8)]],
104+
const constant float& beta [[buffer(9)]],
105+
uint2 gid [[thread_position_in_grid]]);
106+
)";

0 commit comments

Comments
 (0)