11// Copyright © 2023 Apple Inc.
2-
32#include < algorithm>
43#include < cassert>
54#include < numeric>
@@ -33,6 +32,9 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
3332 }
3433
3534 out.set_data (allocator::malloc_or_wait (out.nbytes ()));
35+ if (out.size () == 0 ) {
36+ return ;
37+ }
3638
3739 auto & s = stream ();
3840 auto & d = metal::device (s.device );
@@ -110,14 +112,18 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
110112 for (int i = 0 ; i < nidx; ++i) {
111113 set_array_buffer (compute_encoder, arg_enc, inputs[i + 1 ], i);
112114 }
113- arg_enc->setBuffer (
114- static_cast <MTL::Buffer*>(idx_shapes_buf.ptr ()), 0 , nidx + 1 );
115- compute_encoder->useResource (
116- static_cast <MTL::Buffer*>(idx_shapes_buf.ptr ()), MTL::ResourceUsageRead);
117- arg_enc->setBuffer (
118- static_cast <MTL::Buffer*>(idx_strides_buf.ptr ()), 0 , nidx + 2 );
119- compute_encoder->useResource (
120- static_cast <MTL::Buffer*>(idx_strides_buf.ptr ()), MTL::ResourceUsageRead);
115+ if (idx_ndim > 0 ) {
116+ arg_enc->setBuffer (
117+ static_cast <MTL::Buffer*>(idx_shapes_buf.ptr ()), 0 , nidx + 1 );
118+ compute_encoder->useResource (
119+ static_cast <MTL::Buffer*>(idx_shapes_buf.ptr ()),
120+ MTL::ResourceUsageRead);
121+ arg_enc->setBuffer (
122+ static_cast <MTL::Buffer*>(idx_strides_buf.ptr ()), 0 , nidx + 2 );
123+ compute_encoder->useResource (
124+ static_cast <MTL::Buffer*>(idx_strides_buf.ptr ()),
125+ MTL::ResourceUsageRead);
126+ }
121127 *static_cast <int *>(arg_enc->constantData (nidx + 3 )) = idx_ndim;
122128
123129 // Set all the buffers
@@ -163,6 +169,11 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
163169 inputs[0 ].data_size () == 1 ? CopyType::Scalar : CopyType::General;
164170 copy_gpu (inputs[0 ], out, copy_type);
165171
172+ // Empty update
173+ if (inputs.back ().size () == 0 ) {
174+ return ;
175+ }
176+
166177 // Get stream
167178 auto & s = stream ();
168179 auto & d = metal::device (s.device );
@@ -254,14 +265,18 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
254265 for (int i = 0 ; i < nidx; ++i) {
255266 set_array_buffer (compute_encoder, arg_enc, inputs[i + 1 ], i);
256267 }
257- arg_enc->setBuffer (
258- static_cast <MTL::Buffer*>(idx_shapes_buf.ptr ()), 0 , nidx + 1 );
259- compute_encoder->useResource (
260- static_cast <MTL::Buffer*>(idx_shapes_buf.ptr ()), MTL::ResourceUsageRead);
261- arg_enc->setBuffer (
262- static_cast <MTL::Buffer*>(idx_strides_buf.ptr ()), 0 , nidx + 2 );
263- compute_encoder->useResource (
264- static_cast <MTL::Buffer*>(idx_strides_buf.ptr ()), MTL::ResourceUsageRead);
268+ if (idx_ndim > 0 ) {
269+ arg_enc->setBuffer (
270+ static_cast <MTL::Buffer*>(idx_shapes_buf.ptr ()), 0 , nidx + 1 );
271+ compute_encoder->useResource (
272+ static_cast <MTL::Buffer*>(idx_shapes_buf.ptr ()),
273+ MTL::ResourceUsageRead);
274+ arg_enc->setBuffer (
275+ static_cast <MTL::Buffer*>(idx_strides_buf.ptr ()), 0 , nidx + 2 );
276+ compute_encoder->useResource (
277+ static_cast <MTL::Buffer*>(idx_strides_buf.ptr ()),
278+ MTL::ResourceUsageRead);
279+ }
265280 *static_cast <int *>(arg_enc->constantData (nidx + 3 )) = idx_ndim;
266281
267282 compute_encoder->setBuffer (static_cast <MTL::Buffer*>(arg_buf.ptr ()), 0 , 0 );
@@ -272,14 +287,32 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
272287 }
273288 set_array_buffer (compute_encoder, upd, 1 );
274289 set_array_buffer (compute_encoder, out, 2 );
275- compute_encoder->setBytes (upd.shape ().data (), upd_ndim * sizeof (int ), 3 );
276- compute_encoder->setBytes (upd.strides ().data (), upd_ndim * sizeof (size_t ), 4 );
290+ if (upd_ndim == 0 ) {
291+ // Need placeholders so Metal doesn't compalain
292+ int shape_ = 0 ;
293+ size_t stride_ = 0 ;
294+ compute_encoder->setBytes (&shape_, sizeof (int ), 3 );
295+ compute_encoder->setBytes (&stride_, sizeof (size_t ), 4 );
296+ } else {
297+ compute_encoder->setBytes (upd.shape ().data (), upd_ndim * sizeof (int ), 3 );
298+ compute_encoder->setBytes (
299+ upd.strides ().data (), upd_ndim * sizeof (size_t ), 4 );
300+ }
277301 compute_encoder->setBytes (&upd_ndim, sizeof (size_t ), 5 );
278302 compute_encoder->setBytes (&upd_size, sizeof (size_t ), 6 );
279303
280304 size_t out_ndim = out.ndim ();
281- compute_encoder->setBytes (out.shape ().data (), out_ndim * sizeof (int ), 7 );
282- compute_encoder->setBytes (out.strides ().data (), out_ndim * sizeof (size_t ), 8 );
305+ if (out_ndim == 0 ) {
306+ // Need placeholders so Metal doesn't compalain
307+ int shape_ = 0 ;
308+ size_t stride_ = 0 ;
309+ compute_encoder->setBytes (&shape_, sizeof (int ), 7 );
310+ compute_encoder->setBytes (&stride_, sizeof (size_t ), 8 );
311+ } else {
312+ compute_encoder->setBytes (out.shape ().data (), out_ndim * sizeof (int ), 7 );
313+ compute_encoder->setBytes (
314+ out.strides ().data (), out_ndim * sizeof (size_t ), 8 );
315+ }
283316 compute_encoder->setBytes (&out_ndim, sizeof (size_t ), 9 );
284317 compute_encoder->setBytes (axes_.data (), axes_.size () * sizeof (int ), 10 );
285318
0 commit comments