1313#include " mlx/backend/metal/metal.h"
1414#include " mlx/backend/metal/metal_impl.h"
1515#include " mlx/backend/metal/utils.h"
16+ #include " mlx/utils.h"
1617
1718namespace mlx ::core::metal {
1819
@@ -124,8 +125,8 @@ MTL::Library* load_library(
124125
125126} // namespace
126127
127- CommandEncoder::CommandEncoder (MTL::CommandBuffer* cbuf ) {
128- enc_ = cbuf ->computeCommandEncoder (MTL::DispatchTypeConcurrent);
128+ CommandEncoder::CommandEncoder (DeviceStream& stream) : stream_(stream ) {
129+ enc_ = stream_. buffer ->computeCommandEncoder (MTL::DispatchTypeConcurrent);
129130 enc_->retain ();
130131}
131132
@@ -145,7 +146,9 @@ void CommandEncoder::set_input_array(
145146 const array& a,
146147 int idx,
147148 int64_t offset /* = 0 */ ) {
148- all_inputs_.insert (a.buffer ().ptr ());
149+ if (all_inputs_.insert (a.buffer ().ptr ()).second ) {
150+ stream_.buffer_sizes += a.data_size ();
151+ }
149152 auto r_buf = static_cast <MTL::Resource*>(const_cast <void *>(a.buffer ().ptr ()));
150153 needs_barrier_ =
151154 needs_barrier_ | (prev_outputs_.find (r_buf) != prev_outputs_.end ());
@@ -190,13 +193,15 @@ void CommandEncoder::dispatch_threadgroups(
190193 MTL::Size grid_dims,
191194 MTL::Size group_dims) {
192195 maybeInsertBarrier ();
196+ stream_.buffer_ops ++;
193197 enc_->dispatchThreadgroups (grid_dims, group_dims);
194198}
195199
196200void CommandEncoder::dispatch_threads (
197201 MTL::Size grid_dims,
198202 MTL::Size group_dims) {
199203 maybeInsertBarrier ();
204+ stream_.buffer_ops ++;
200205 enc_->dispatchThreads (grid_dims, group_dims);
201206}
202207
@@ -209,6 +214,31 @@ Device::Device() {
209214 device_ = load_device ();
210215 library_map_ = {{" mlx" , load_library (device_)}};
211216 arch_ = std::string (device_->architecture ()->name ()->utf8String ());
217+ auto arch = arch_.back ();
218+ switch (arch) {
219+ case ' p' : // phone
220+ max_ops_per_buffer_ = 20 ;
221+ max_mb_per_buffer_ = 40 ;
222+ break ;
223+ case ' g' : // base, pro
224+ max_ops_per_buffer_ = 40 ;
225+ max_mb_per_buffer_ = 40 ;
226+ break ;
227+ case ' s' : // max
228+ max_ops_per_buffer_ = 50 ;
229+ max_mb_per_buffer_ = 50 ;
230+ break ;
231+ case ' d' : // ultra
232+ max_ops_per_buffer_ = 50 ;
233+ max_mb_per_buffer_ = 50 ;
234+ break ;
235+ default : // default to medium
236+ max_ops_per_buffer_ = 40 ;
237+ max_mb_per_buffer_ = 40 ;
238+ break ;
239+ }
240+ max_ops_per_buffer_ = env::max_ops_per_buffer (max_ops_per_buffer_);
241+ max_mb_per_buffer_ = env::max_mb_per_buffer (max_mb_per_buffer_);
212242}
213243
214244Device::~Device () {
@@ -239,12 +269,13 @@ void Device::new_queue(int index) {
239269 }
240270}
241271
242- int Device::get_command_buffer_ops (int index) {
243- return get_stream_ (index).buffer_ops ;
244- }
245-
246- void Device::increment_command_buffer_ops (int index) {
247- get_stream_ (index).buffer_ops ++;
272+ bool Device::command_buffer_needs_commit (int index) {
273+ auto & stream = get_stream_ (index);
274+ if (stream.buffer_ops > max_ops_per_buffer_ ||
275+ (stream.buffer_sizes >> 20 ) > max_mb_per_buffer_) {
276+ return true ;
277+ }
278+ return false ;
248279}
249280
250281MTL::CommandBuffer* Device::get_command_buffer (int index) {
@@ -267,6 +298,7 @@ void Device::commit_command_buffer(int index) {
267298 stream.buffer ->release ();
268299 stream.buffer = nullptr ;
269300 stream.buffer_ops = 0 ;
301+ stream.buffer_sizes = 0 ;
270302}
271303
272304void Device::add_temporary (array arr, int index) {
@@ -351,7 +383,7 @@ void Device::end_encoding(int index) {
351383CommandEncoder& Device::get_command_encoder (int index) {
352384 auto & stream = get_stream_ (index);
353385 if (stream.encoder == nullptr ) {
354- stream.encoder = std::make_unique<CommandEncoder>(stream. buffer );
386+ stream.encoder = std::make_unique<CommandEncoder>(stream);
355387 stream.fence = std::make_shared<Fence>(device_->newFence ());
356388 }
357389 return *stream.encoder ;
0 commit comments