Skip to content

Commit 960e3f0

Browse files
authored
Gemm update (#1518)
1 parent 884af42 commit 960e3f0

File tree

9 files changed

+699
-194
lines changed

9 files changed

+699
-194
lines changed

mlx/backend/metal/device.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -181,6 +181,7 @@ Device::Device() {
181181
auto pool = new_scoped_memory_pool();
182182
device_ = load_device();
183183
library_map_ = {{"mlx", load_library(device_)}};
184+
arch_ = std::string(device_->architecture()->name()->utf8String());
184185
}
185186

186187
Device::~Device() {

mlx/backend/metal/device.h

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,10 @@ class Device {
136136
return device_;
137137
};
138138

139+
const std::string& get_architecture() {
140+
return arch_;
141+
}
142+
139143
void new_queue(int index);
140144
MTL::CommandBuffer* get_command_buffer(int index);
141145
int get_command_buffer_ops(int index);
@@ -228,6 +232,7 @@ class Device {
228232
std::shared_mutex library_mtx_;
229233
std::unordered_map<std::string, MTL::Library*> library_map_;
230234
const MTL::ResidencySet* residency_set_{nullptr};
235+
std::string arch_;
231236
};
232237

233238
Device& device(mlx::core::Device);

mlx/backend/metal/kernels/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,9 @@ set(STEEL_HEADERS
5050
steel/gemm/transforms.h
5151
steel/gemm/kernels/steel_gemm_fused.h
5252
steel/gemm/kernels/steel_gemm_masked.h
53-
steel/gemm/kernels/steel_gemm_splitk.h)
53+
steel/gemm/kernels/steel_gemm_splitk.h
54+
steel/utils/type_traits.h
55+
steel/utils/integral_constant.h)
5456

5557
if(NOT MLX_METAL_JIT)
5658
build_kernel(arange arange.h)

mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -142,8 +142,8 @@ implicit_gemm_conv_2d_general(
142142
// Store results to device memory
143143
{
144144
// Adjust for simdgroup and thread location
145-
int offset_m = c_row + mma_op.sm + mma_op.tm;
146-
int offset_n = c_col + mma_op.sn + mma_op.tn;
145+
int offset_m = c_row + mma_op.sm;
146+
int offset_n = c_col + mma_op.sn;
147147
C += offset_n;
148148

149149
if (offset_n >= gemm_params->N)
@@ -169,17 +169,17 @@ implicit_gemm_conv_2d_general(
169169
STEEL_PRAGMA_UNROLL
170170
for (int j = 0; j < mma_t::TN; j++) {
171171
// Get accumulated result and associated offset in C
172-
thread const auto& accum =
173-
mma_op.results[i * mma_t::TN + j].thread_elements();
172+
thread const auto& accum = mma_op.Ctile.frag_at(i, j);
174173
int offset = offset_cm + (j * mma_t::TN_stride);
175174

176-
// Apply epilogue and output C
177-
if (j * mma_t::TN_stride < diff) {
178-
C[offset] = Epilogue::apply(accum[0]);
179-
}
175+
constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;
180176

181-
if (j * mma_t::TN_stride + 1 < diff) {
182-
C[offset + 1] = Epilogue::apply(accum[1]);
177+
// Apply epilogue and output C
178+
STEEL_PRAGMA_UNROLL
179+
for (short k = 0; k < kelems; k++) {
180+
if ((j * mma_t::TN_stride + k) < diff) {
181+
C[offset + k] = Epilogue::apply(accum[k]);
182+
}
183183
}
184184
}
185185
}

mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -36,11 +36,11 @@
3636
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)
3737

3838
#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
39-
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
4039
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
40+
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
4141
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
42-
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
43-
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
42+
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
43+
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)
4444

4545
instantiate_gemm_shapes_helper(float16, half, float16, half);
4646
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);

0 commit comments

Comments
 (0)