@@ -25,6 +25,7 @@ namespace {
2525
2626// TODO nicer way to set this or possibly expose as an environment variable
2727constexpr int MAX_BUFFERS_PER_QUEUE = 12 ;
28+ constexpr int MAX_DISPATCHES_PER_ENCODER = 2 ;
2829
2930constexpr const char * default_mtllib_path = METAL_PATH;
3031
@@ -37,7 +38,6 @@ auto load_device() {
3738 }
3839 return device;
3940}
40-
4141std::pair<MTL::Library*, NS::Error*> load_library_from_path (
4242 MTL::Device* device,
4343 const char * path) {
@@ -116,6 +116,33 @@ MTL::Library* load_library(
116116
117117} // namespace
118118
119+ void CommandEncoder::dispatchThreadgroups (
120+ MTL::Size grid_dims,
121+ MTL::Size group_dims) {
122+ num_dispatches++;
123+ enc->dispatchThreadgroups (grid_dims, group_dims);
124+ maybe_split ();
125+ }
126+
127+ void CommandEncoder::dispatchThreads (
128+ MTL::Size grid_dims,
129+ MTL::Size group_dims) {
130+ num_dispatches++;
131+ enc->dispatchThreads (grid_dims, group_dims);
132+ maybe_split ();
133+ }
134+
135+ void CommandEncoder::maybe_split () {
136+ if (num_dispatches > MAX_DISPATCHES_PER_ENCODER && !concurrent) {
137+ enc->endEncoding ();
138+ enc->release ();
139+ num_dispatches = 0 ;
140+ outputs.clear ();
141+ enc = cbuf->computeCommandEncoder (MTL::DispatchTypeConcurrent);
142+ enc->retain ();
143+ }
144+ }
145+
119146Device::Device () {
120147 auto pool = new_scoped_memory_pool ();
121148 device_ = load_device ();
@@ -130,9 +157,6 @@ Device::~Device() {
130157 for (auto & b : buffer_map_) {
131158 b.second .second ->release ();
132159 }
133- for (auto & e : encoder_map_) {
134- (*e.second )->release ();
135- }
136160 for (auto & k : kernel_map_) {
137161 k.second ->release ();
138162 }
@@ -169,27 +193,26 @@ void Device::increment_command_buffer_ops(int index) {
169193
170194MTL::CommandBuffer* Device::get_command_buffer (int index) {
171195 auto bit = buffer_map_.find (index);
172- return (bit == buffer_map_.end ()) ? nullptr : bit->second .second ;
173- }
174-
175- MTL::CommandBuffer* Device::new_command_buffer (int index) {
176- auto qit = queue_map_.find (index);
177- if (qit == queue_map_.end ()) {
178- throw std::runtime_error (
179- " [metal::Device] Attempting to get command buffer for invalid queue." );
180- }
196+ if (bit == buffer_map_.end ()) {
197+ auto qit = queue_map_.find (index);
198+ if (qit == queue_map_.end ()) {
199+ throw std::runtime_error (
200+ " [metal::Device] Attempting to get command buffer for invalid queue." );
201+ }
181202
182- auto cb = qit->second ->commandBufferWithUnretainedReferences ();
203+ auto cb = qit->second ->commandBufferWithUnretainedReferences ();
183204
184- if (!cb) {
185- throw std::runtime_error (
186- " [metal::Device] Unable to create new command buffer" );
187- }
205+ if (!cb) {
206+ throw std::runtime_error (
207+ " [metal::Device] Unable to create new command buffer" );
208+ }
188209
189- // Increment ref count so the buffer is not garbage collected
190- cb->retain ();
210+ // Increment ref count so the buffer is not garbage collected
211+ cb->retain ();
191212
192- return buffer_map_.insert ({index, {0 , cb}}).first ->second .second ;
213+ bit = buffer_map_.insert ({index, {0 , cb}}).first ;
214+ }
215+ return bit->second .second ;
193216}
194217
195218void Device::commit_command_buffer (int index) {
@@ -200,25 +223,15 @@ void Device::commit_command_buffer(int index) {
200223}
201224
202225void Device::end_encoding (int index) {
203- auto eit = encoder_map_.find (index);
204- if (eit != encoder_map_.end ()) {
205- (*eit->second )->endEncoding ();
206- (*eit->second )->release ();
207- encoder_map_.erase (eit);
208- }
226+ encoder_map_.erase (index);
209227}
210228
211229CommandEncoder& Device::get_command_encoder (int index) {
212230 auto eit = encoder_map_.find (index);
213231 if (eit == encoder_map_.end ()) {
214232 auto cb = get_command_buffer (index);
215- auto compute_encoder =
216- cb->computeCommandEncoder (MTL::DispatchTypeConcurrent);
217- // Increment ref count so the buffer is not garbage collected
218- compute_encoder->retain ();
219- eit = encoder_map_
220- .emplace (index, std::make_unique<CommandEncoder>(compute_encoder))
221- .first ;
233+ eit =
234+ encoder_map_.emplace (index, std::make_unique<CommandEncoder>(cb)).first ;
222235 }
223236 return *(eit->second );
224237}
0 commit comments