Skip to content

Commit 7aea5b1

Browse files
authored
Allow dynamic ops per buffer based on dispatches and memory (#1864)
* Allow dynamic ops per buffer based on dispatches and memory * add initial arch values
1 parent 9733e16 commit 7aea5b1

File tree

5 files changed

+62
-21
lines changed

5 files changed

+62
-21
lines changed

mlx/backend/metal/device.cpp

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlx/backend/metal/metal.h"
1414
#include "mlx/backend/metal/metal_impl.h"
1515
#include "mlx/backend/metal/utils.h"
16+
#include "mlx/utils.h"
1617

1718
namespace mlx::core::metal {
1819

@@ -124,8 +125,8 @@ MTL::Library* load_library(
124125

125126
} // namespace
126127

127-
CommandEncoder::CommandEncoder(MTL::CommandBuffer* cbuf) {
128-
enc_ = cbuf->computeCommandEncoder(MTL::DispatchTypeConcurrent);
128+
CommandEncoder::CommandEncoder(DeviceStream& stream) : stream_(stream) {
129+
enc_ = stream_.buffer->computeCommandEncoder(MTL::DispatchTypeConcurrent);
129130
enc_->retain();
130131
}
131132

@@ -145,7 +146,9 @@ void CommandEncoder::set_input_array(
145146
const array& a,
146147
int idx,
147148
int64_t offset /* = 0 */) {
148-
all_inputs_.insert(a.buffer().ptr());
149+
if (all_inputs_.insert(a.buffer().ptr()).second) {
150+
stream_.buffer_sizes += a.data_size();
151+
}
149152
auto r_buf = static_cast<MTL::Resource*>(const_cast<void*>(a.buffer().ptr()));
150153
needs_barrier_ =
151154
needs_barrier_ | (prev_outputs_.find(r_buf) != prev_outputs_.end());
@@ -190,13 +193,15 @@ void CommandEncoder::dispatch_threadgroups(
190193
MTL::Size grid_dims,
191194
MTL::Size group_dims) {
192195
maybeInsertBarrier();
196+
stream_.buffer_ops++;
193197
enc_->dispatchThreadgroups(grid_dims, group_dims);
194198
}
195199

196200
void CommandEncoder::dispatch_threads(
197201
MTL::Size grid_dims,
198202
MTL::Size group_dims) {
199203
maybeInsertBarrier();
204+
stream_.buffer_ops++;
200205
enc_->dispatchThreads(grid_dims, group_dims);
201206
}
202207

@@ -209,6 +214,31 @@ Device::Device() {
209214
device_ = load_device();
210215
library_map_ = {{"mlx", load_library(device_)}};
211216
arch_ = std::string(device_->architecture()->name()->utf8String());
217+
auto arch = arch_.back();
218+
switch (arch) {
219+
case 'p': // phone
220+
max_ops_per_buffer_ = 20;
221+
max_mb_per_buffer_ = 40;
222+
break;
223+
case 'g': // base, pro
224+
max_ops_per_buffer_ = 40;
225+
max_mb_per_buffer_ = 40;
226+
break;
227+
case 's': // max
228+
max_ops_per_buffer_ = 50;
229+
max_mb_per_buffer_ = 50;
230+
break;
231+
case 'd': // ultra
232+
max_ops_per_buffer_ = 50;
233+
max_mb_per_buffer_ = 50;
234+
break;
235+
default: // default to medium
236+
max_ops_per_buffer_ = 40;
237+
max_mb_per_buffer_ = 40;
238+
break;
239+
}
240+
max_ops_per_buffer_ = env::max_ops_per_buffer(max_ops_per_buffer_);
241+
max_mb_per_buffer_ = env::max_mb_per_buffer(max_mb_per_buffer_);
212242
}
213243

214244
Device::~Device() {
@@ -239,12 +269,13 @@ void Device::new_queue(int index) {
239269
}
240270
}
241271

242-
int Device::get_command_buffer_ops(int index) {
243-
return get_stream_(index).buffer_ops;
244-
}
245-
246-
void Device::increment_command_buffer_ops(int index) {
247-
get_stream_(index).buffer_ops++;
272+
bool Device::command_buffer_needs_commit(int index) {
273+
auto& stream = get_stream_(index);
274+
if (stream.buffer_ops > max_ops_per_buffer_ ||
275+
(stream.buffer_sizes >> 20) > max_mb_per_buffer_) {
276+
return true;
277+
}
278+
return false;
248279
}
249280

250281
MTL::CommandBuffer* Device::get_command_buffer(int index) {
@@ -267,6 +298,7 @@ void Device::commit_command_buffer(int index) {
267298
stream.buffer->release();
268299
stream.buffer = nullptr;
269300
stream.buffer_ops = 0;
301+
stream.buffer_sizes = 0;
270302
}
271303

272304
void Device::add_temporary(array arr, int index) {
@@ -351,7 +383,7 @@ void Device::end_encoding(int index) {
351383
CommandEncoder& Device::get_command_encoder(int index) {
352384
auto& stream = get_stream_(index);
353385
if (stream.encoder == nullptr) {
354-
stream.encoder = std::make_unique<CommandEncoder>(stream.buffer);
386+
stream.encoder = std::make_unique<CommandEncoder>(stream);
355387
stream.fence = std::make_shared<Fence>(device_->newFence());
356388
}
357389
return *stream.encoder;

mlx/backend/metal/device.h

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@ inline std::string get_colocated_mtllib_path(const std::string& lib_name) {
3838
using MTLFCList =
3939
std::vector<std::tuple<const void*, MTL::DataType, NS::UInteger>>;
4040

41+
struct DeviceStream;
42+
4143
struct CommandEncoder {
42-
CommandEncoder(MTL::CommandBuffer* cbuf);
44+
explicit CommandEncoder(DeviceStream& stream);
4345
CommandEncoder(const CommandEncoder&) = delete;
4446
CommandEncoder& operator=(const CommandEncoder&) = delete;
4547

@@ -115,6 +117,7 @@ struct CommandEncoder {
115117
void barrier();
116118

117119
private:
120+
DeviceStream& stream_;
118121
MTL::ComputeCommandEncoder* enc_;
119122
bool needs_barrier_{false};
120123
bool concurrent_{false};
@@ -147,10 +150,10 @@ struct DeviceStream {
147150
// Used to allow thread-safe access to the outputs map
148151
std::mutex fence_mtx;
149152

150-
// The buffer and buffer op count are updated
151-
// between command buffers
153+
// Data updated between command buffers
152154
MTL::CommandBuffer* buffer{nullptr};
153155
int buffer_ops{0};
156+
size_t buffer_sizes{0};
154157

155158
// The command encoder, fence, and temporaries are updated between command
156159
// encoders
@@ -176,8 +179,7 @@ class Device {
176179

177180
void new_queue(int index);
178181
MTL::CommandBuffer* get_command_buffer(int index);
179-
int get_command_buffer_ops(int index);
180-
void increment_command_buffer_ops(int index);
182+
bool command_buffer_needs_commit(int index);
181183
void commit_command_buffer(int index);
182184
CommandEncoder& get_command_encoder(int index);
183185
void end_encoding(int index);
@@ -267,6 +269,8 @@ class Device {
267269
std::unordered_map<std::string, MTL::Library*> library_map_;
268270
const MTL::ResidencySet* residency_set_{nullptr};
269271
std::string arch_;
272+
int max_ops_per_buffer_;
273+
int max_mb_per_buffer_;
270274
};
271275

272276
Device& device(mlx::core::Device);

mlx/backend/metal/matmul.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ std::tuple<bool, int64_t, array> check_transpose(
109109
///////////////////////////////////////////////////////////////////////////////
110110

111111
#define GEMM_TPARAM_MACRO(devc) \
112-
if (devc == 'g') { /* Small device */ \
112+
if (devc == 'g' || devc == 'p') { /* Small device */ \
113113
if (!transpose_a && transpose_b) { /* nt */ \
114114
bm = 64; \
115115
bn = 32; \

mlx/backend/metal/metal.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ std::function<void()> make_task(array arr, bool signal) {
2929
auto s = arr.primitive().stream();
3030
auto& d = metal::device(s.device);
3131
auto command_buffer = d.get_command_buffer(s.index);
32-
d.increment_command_buffer_ops(s.index);
3332

3433
for (auto& input : arr.inputs()) {
3534
if (input.event().valid() &&
@@ -68,8 +67,7 @@ std::function<void()> make_task(array arr, bool signal) {
6867
out.set_status(array::Status::evaluated);
6968
}
7069

71-
if (signal ||
72-
d.get_command_buffer_ops(s.index) >= env::max_ops_per_buffer()) {
70+
if (signal || d.command_buffer_needs_commit(s.index)) {
7371
if (signal) {
7472
encode_signal(arr.event());
7573
}

mlx/utils.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,18 @@ inline int bfs_max_width() {
122122
return bfs_max_width_;
123123
}
124124

125-
inline int max_ops_per_buffer() {
126-
static int max_ops_per_buffer_ = get_var("MLX_MAX_OPS_PER_BUFFER", 10);
125+
inline int max_ops_per_buffer(int default_value) {
126+
static int max_ops_per_buffer_ =
127+
get_var("MLX_MAX_OPS_PER_BUFFER", default_value);
127128
return max_ops_per_buffer_;
128129
}
129130

131+
inline int max_mb_per_buffer(int default_value) {
132+
static int max_mb_per_buffer_ =
133+
get_var("MLX_MAX_MB_PER_BUFFER", default_value);
134+
return max_mb_per_buffer_;
135+
}
136+
130137
inline bool metal_fast_synch() {
131138
static bool metal_fast_synch = get_var("MLX_METAL_FAST_SYNCH", 0);
132139
return metal_fast_synch;

0 commit comments

Comments
 (0)