Commit 06375e6

Split encoders in non-concurrent context with a max ops per encoder (#1085)

* split encoders
* fix race

1 parent b21242f commit 06375e6

18 files changed: +148 -136 lines changed
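In short, kernel call sites now dispatch through the CommandEncoder wrapper (compute_encoder.dispatchThreads(...) instead of compute_encoder->dispatchThreads(...) on the raw Metal encoder), the wrapper counts dispatches, and once the count exceeds MAX_DISPATCHES_PER_ENCODER it ends the current encoder and opens a fresh one on the same command buffer, unless a concurrent context is active. Below is a minimal, self-contained sketch of that idea only; FakeBuffer, FakeEncoder, and EncoderWrapper are made-up stand-ins, not MLX or metal-cpp types.

// Hedged sketch of the split-after-N-dispatches pattern (illustrative names only).
#include <cstdio>

constexpr int kMaxDispatchesPerEncoder = 2; // mirrors MAX_DISPATCHES_PER_ENCODER below

struct FakeEncoder {
  int id;
  void dispatch() { std::printf("dispatch on encoder %d\n", id); }
  void end() { std::printf("end encoder %d\n", id); }
};

struct FakeBuffer {
  int next_id = 0;
  FakeEncoder make_encoder() { return FakeEncoder{next_id++}; }
};

// Counts dispatches and, outside of a concurrent context, ends the current
// encoder and opens a new one on the same buffer once the cap is exceeded.
struct EncoderWrapper {
  explicit EncoderWrapper(FakeBuffer& buf) : buf(buf), enc(buf.make_encoder()) {}
  ~EncoderWrapper() { enc.end(); }

  void dispatch() {
    num_dispatches++;
    enc.dispatch();
    maybe_split();
  }

  bool concurrent = false;

 private:
  void maybe_split() {
    if (num_dispatches > kMaxDispatchesPerEncoder && !concurrent) {
      enc.end();                 // close the encoder that hit the cap
      num_dispatches = 0;        // reset the counter
      enc = buf.make_encoder();  // continue encoding on a fresh encoder
    }
  }

  int num_dispatches = 0;
  FakeBuffer& buf;
  FakeEncoder enc;
};

int main() {
  FakeBuffer buf;
  EncoderWrapper w(buf);
  for (int i = 0; i < 5; ++i) {
    w.dispatch(); // the 3rd dispatch exceeds the cap of 2 and triggers a split
  }
} // w's destructor ends the final encoder

With the cap at 2, the third dispatch closes encoder 0 and later work lands on encoder 1, which mirrors how maybe_split() in device.cpp below tears down and recreates the MTL::ComputeCommandEncoder on the same command buffer.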

mlx/backend/metal/compiled.cpp

Lines changed: 2 additions & 2 deletions
@@ -336,7 +336,7 @@ void Compiled::eval_gpu(
     MTL::Size grid_dims(nthreads, 1, 1);
     MTL::Size group_dims(
         std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     size_t dim0 = ndim > 0 ? shape[ndim - 1] : 1;
     size_t dim1 = ndim > 1 ? shape[ndim - 2] : 1;
@@ -347,7 +347,7 @@ void Compiled::eval_gpu(
     }
     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }

mlx/backend/metal/conv.cpp

Lines changed: 8 additions & 8 deletions
@@ -59,7 +59,7 @@ void explicit_gemm_conv_ND_gpu(
   MTL::Size grid_dims = MTL::Size(
       conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);

-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);

   // Reshape weight
   std::vector<int> wt_reshape{implicit_K, implicit_N};
@@ -137,7 +137,7 @@ void explicit_gemm_conv_group_ND_gpu(
   MTL::Size grid_dims = MTL::Size(
       conv_params.C, unfolded_shape[1] / conv_params.C, unfolded_shape[0]);

-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);

   // Transpose kernel weights so that we can slice them by contiguous chunks
   // of channel groups.
@@ -247,7 +247,7 @@ void slow_conv_2D_gpu(
   compute_encoder.set_output_array(out, 2);

   compute_encoder->setBytes(&conv_params, sizeof(MLXConvParams<2>), 3);
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }

 void implicit_gemm_conv_2D_gpu(
@@ -352,7 +352,7 @@ void implicit_gemm_conv_2D_gpu(
   compute_encoder->setBytes(&gemm_params, sizeof(ImplicitGemmConv2DParams), 4);

   // Launch kernel
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }

 void implicit_gemm_conv_2D_general_gpu(
@@ -512,7 +512,7 @@ void implicit_gemm_conv_2D_general_gpu(
       base_w.data(), sizeof(Conv2DGeneralBaseInfo) * base_w.size(), 7);

   // Launch kernel
-  compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+  compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
 }

 void winograd_conv_2D_gpu(
@@ -613,7 +613,7 @@ void winograd_conv_2D_gpu(
     MTL::Size group_dims = MTL::Size(32, bo, 1);
     MTL::Size grid_dims = MTL::Size(O_c / bo, 1, 1);

-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }

   // Do input transform
@@ -641,7 +641,7 @@ void winograd_conv_2D_gpu(
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }

   // Do batched gemm
@@ -689,7 +689,7 @@ void winograd_conv_2D_gpu(
     MTL::Size group_dims = MTL::Size(32, wn, wm);
     MTL::Size grid_dims = MTL::Size(N_tiles_w, N_tiles_h, N_tiles_n);

-    compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
+    compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
   }
 }
mlx/backend/metal/copy.cpp

Lines changed: 2 additions & 2 deletions
@@ -126,7 +126,7 @@ void copy_gpu_inplace(

     auto group_dims = get_block_dims(dim0, dim1, rest);
     MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   } else {
     size_t nthreads = out.data_size();
     MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
@@ -135,7 +135,7 @@ void copy_gpu_inplace(
       thread_group_size = nthreads;
     }
     MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }

mlx/backend/metal/device.cpp

Lines changed: 47 additions & 34 deletions
@@ -25,6 +25,7 @@ namespace {

 // TODO nicer way to set this or possibly expose as an environment variable
 constexpr int MAX_BUFFERS_PER_QUEUE = 12;
+constexpr int MAX_DISPATCHES_PER_ENCODER = 2;

 constexpr const char* default_mtllib_path = METAL_PATH;

@@ -37,7 +38,6 @@ auto load_device() {
   }
   return device;
 }
-
 std::pair<MTL::Library*, NS::Error*> load_library_from_path(
     MTL::Device* device,
     const char* path) {
@@ -116,6 +116,33 @@ MTL::Library* load_library(

 } // namespace

+void CommandEncoder::dispatchThreadgroups(
+    MTL::Size grid_dims,
+    MTL::Size group_dims) {
+  num_dispatches++;
+  enc->dispatchThreadgroups(grid_dims, group_dims);
+  maybe_split();
+}
+
+void CommandEncoder::dispatchThreads(
+    MTL::Size grid_dims,
+    MTL::Size group_dims) {
+  num_dispatches++;
+  enc->dispatchThreads(grid_dims, group_dims);
+  maybe_split();
+}
+
+void CommandEncoder::maybe_split() {
+  if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
+    enc->endEncoding();
+    enc->release();
+    num_dispatches = 0;
+    outputs.clear();
+    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+    enc->retain();
+  }
+}
+
 Device::Device() {
   auto pool = new_scoped_memory_pool();
   device_ = load_device();
@@ -130,9 +157,6 @@ Device::~Device() {
   for (auto& b : buffer_map_) {
     b.second.second->release();
   }
-  for (auto& e : encoder_map_) {
-    (*e.second)->release();
-  }
   for (auto& k : kernel_map_) {
     k.second->release();
   }
@@ -169,27 +193,26 @@ void Device::increment_command_buffer_ops(int index) {

 MTL::CommandBuffer* Device::get_command_buffer(int index) {
   auto bit = buffer_map_.find(index);
-  return (bit == buffer_map_.end()) ? nullptr : bit->second.second;
-}
-
-MTL::CommandBuffer* Device::new_command_buffer(int index) {
-  auto qit = queue_map_.find(index);
-  if (qit == queue_map_.end()) {
-    throw std::runtime_error(
-        "[metal::Device] Attempting to get command buffer for invalid queue.");
-  }
+  if (bit == buffer_map_.end()) {
+    auto qit = queue_map_.find(index);
+    if (qit == queue_map_.end()) {
+      throw std::runtime_error(
+          "[metal::Device] Attempting to get command buffer for invalid queue.");
+    }

-  auto cb = qit->second->commandBufferWithUnretainedReferences();
+    auto cb = qit->second->commandBufferWithUnretainedReferences();

-  if (!cb) {
-    throw std::runtime_error(
-        "[metal::Device] Unable to create new command buffer");
-  }
+    if (!cb) {
+      throw std::runtime_error(
+          "[metal::Device] Unable to create new command buffer");
+    }

-  // Increment ref count so the buffer is not garbage collected
-  cb->retain();
+    // Increment ref count so the buffer is not garbage collected
+    cb->retain();

-  return buffer_map_.insert({index, {0, cb}}).first->second.second;
+    bit = buffer_map_.insert({index, {0, cb}}).first;
+  }
+  return bit->second.second;
 }

 void Device::commit_command_buffer(int index) {
@@ -200,25 +223,15 @@ void Device::commit_command_buffer(int index) {
 }

 void Device::end_encoding(int index) {
-  auto eit = encoder_map_.find(index);
-  if (eit != encoder_map_.end()) {
-    (*eit->second)->endEncoding();
-    (*eit->second)->release();
-    encoder_map_.erase(eit);
-  }
+  encoder_map_.erase(index);
 }

 CommandEncoder& Device::get_command_encoder(int index) {
   auto eit = encoder_map_.find(index);
   if (eit == encoder_map_.end()) {
     auto cb = get_command_buffer(index);
-    auto compute_encoder =
-        cb->computeCommandEncoder(MTL::DispatchTypeConcurrent);
-    // Increment ref count so the buffer is not garbage collected
-    compute_encoder->retain();
-    eit = encoder_map_
-              .emplace(index, std::make_unique<CommandEncoder>(compute_encoder))
-              .first;
+    eit =
+        encoder_map_.emplace(index, std::make_unique<CommandEncoder>(cb)).first;
   }
   return *(eit->second);
 }
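The reworked get_command_buffer above now lazily creates the command buffer the first time a queue index is requested, which is why the separate new_command_buffer entry point disappears. A rough, self-contained sketch of that find-or-create shape, with hypothetical Queue/Buffer types standing in for the metal-cpp ones:

// Hedged sketch of the lazy find-or-create pattern (illustrative types only).
#include <stdexcept>
#include <unordered_map>
#include <utility>

struct Buffer {};
struct Queue {};

std::unordered_map<int, Queue> queue_map;
std::unordered_map<int, std::pair<int, Buffer>> buffer_map; // {op count, buffer}

Buffer* get_command_buffer(int index) {
  auto bit = buffer_map.find(index);
  if (bit == buffer_map.end()) {
    // First request for this queue index: check the queue exists, then create.
    if (queue_map.find(index) == queue_map.end()) {
      throw std::runtime_error("invalid queue index");
    }
    bit = buffer_map.insert({index, {0, Buffer{}}}).first;
  }
  return &bit->second.second;
}

int main() {
  queue_map.emplace(0, Queue{});
  auto* a = get_command_buffer(0); // creates and caches the buffer
  auto* b = get_command_buffer(0); // returns the cached one
  return a == b ? 0 : 1;
}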

mlx/backend/metal/device.h

Lines changed: 17 additions & 4 deletions
@@ -37,8 +37,10 @@ using MTLFCList =
     std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;

 struct CommandEncoder {
-  CommandEncoder(MTL::ComputeCommandEncoder* enc)
-      : enc(enc), concurrent(false) {};
+  CommandEncoder(MTL::CommandBuffer* cbuf) : cbuf(cbuf) {
+    enc = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
+    enc->retain();
+  };
   CommandEncoder(const CommandEncoder&) = delete;
   CommandEncoder& operator=(const CommandEncoder&) = delete;

@@ -89,13 +91,25 @@ struct CommandEncoder {
     }
   }

+  void dispatchThreadgroups(MTL::Size grid_dims, MTL::Size group_dims);
+  void dispatchThreads(MTL::Size grid_dims, MTL::Size group_dims);
+
   ConcurrentContext start_concurrent() {
     return ConcurrentContext(*this);
   }

+  ~CommandEncoder() {
+    enc->endEncoding();
+    enc->release();
+  }
+
  private:
+  void maybe_split();
+
+  int num_dispatches{0};
+  MTL::CommandBuffer* cbuf;
   MTL::ComputeCommandEncoder* enc;
-  bool concurrent;
+  bool concurrent{false};
   std::unordered_set<MTL::Resource*> outputs;
   std::unordered_set<MTL::Resource*> concurrent_outputs;
 };
@@ -112,7 +126,6 @@ class Device {
   };

   void new_queue(int index);
-  MTL::CommandBuffer* new_command_buffer(int index);
   MTL::CommandBuffer* get_command_buffer(int index);
   int get_command_buffer_ops(int index);
   void increment_command_buffer_ops(int index);
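One way to read this header change: CommandEncoder now owns its MTL::ComputeCommandEncoder for its whole lifetime (created and retained in the constructor, ended and released in the destructor), which is why Device::end_encoding in device.cpp shrinks to a single encoder_map_.erase(index). A toy RAII model of that lifetime, using hypothetical Mock* types rather than metal-cpp:

// Hedged sketch of the RAII encoder lifetime (illustrative types only).
#include <cstdio>
#include <memory>
#include <unordered_map>

struct MockEncoder {
  void endEncoding() { std::puts("endEncoding"); }
};

struct MockBuffer {
  MockEncoder* computeCommandEncoder() { return new MockEncoder(); }
};

struct Encoder {
  explicit Encoder(MockBuffer* cbuf)
      : cbuf(cbuf), enc(cbuf->computeCommandEncoder()) {}
  ~Encoder() {
    enc->endEncoding();
    delete enc; // stands in for enc->release()
  }
  MockBuffer* cbuf;
  MockEncoder* enc;
};

std::unordered_map<int, std::unique_ptr<Encoder>> encoder_map;

void end_encoding(int index) {
  // Erasing the entry destroys the Encoder, whose destructor ends the encoding.
  encoder_map.erase(index);
}

int main() {
  MockBuffer buf;
  encoder_map.emplace(0, std::make_unique<Encoder>(&buf));
  end_encoding(0); // prints "endEncoding"
}

Because the destructor always ends the encoding, callers no longer need the explicit endEncoding()/release() sequence that the old end_encoding performed.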

mlx/backend/metal/fft.cpp

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {

     auto group_dims = MTL::Size(1, m, 1);
     auto grid_dims = MTL::Size(batch, m, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
   d.get_command_buffer(s.index)->addCompletedHandler(
       [copies](MTL::CommandBuffer*) mutable { copies.clear(); });

mlx/backend/metal/indexing.cpp

Lines changed: 3 additions & 3 deletions
@@ -107,7 +107,7 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   }

   // Launch grid
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }

 void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -216,7 +216,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Launch grid
     MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
     MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);

   } else {
     // Collect all idx shapes and strides into one place
@@ -286,7 +286,7 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
     // Launch grid
     MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
     MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder->dispatchThreads(grid_dims, group_dims);
+    compute_encoder.dispatchThreads(grid_dims, group_dims);
   }
 }
