@@ -113,39 +113,38 @@ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Collect all idx shapes and strides into one place
   std::vector<int> idx_shapes;
   std::vector<size_t> idx_strides;
-
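+  // To access .data() use char instead of bool
+  // (bool is 1 byte in Metal so this is safe)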
+  std::vector<char> idx_contigs;
   for (int i = 0; i < nidx; ++i) {
     idx_shapes.insert(
         idx_shapes.end(),
         inputs[i + 1].shape().begin(),
         inputs[i + 1].shape().end());
-
     idx_strides.insert(
         idx_strides.end(),
         inputs[i + 1].strides().begin(),
         inputs[i + 1].strides().end());
+    idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
   }
 
   // Set all the buffers
   compute_encoder.set_input_array(src, 0);
   compute_encoder.set_output_array(out, 1);
 
   // Set source info
-  compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 2);
-  compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 3);
+  set_vector_bytes(compute_encoder, src.shape(), 2);
+  set_vector_bytes(compute_encoder, src.strides(), 3);
   compute_encoder->setBytes(&ndim, sizeof(size_t), 4);
-  compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 5);
-  compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 6);
+  set_vector_bytes(compute_encoder, slice_sizes_, 5);
+  set_vector_bytes(compute_encoder, axes_, 6);
 
   // Set index info
   //
   // We don't need to check for empty idx_shapes because gather has an
   // idx_ndim == 0 specialization
-  compute_encoder->setBytes(
-      idx_shapes.data(), idx_shapes.size() * sizeof(int), 7);
-  compute_encoder->setBytes(
-      idx_strides.data(), idx_strides.size() * sizeof(size_t), 8);
-  compute_encoder->setBytes(&idx_ndim, sizeof(int), 9);
+  set_vector_bytes(compute_encoder, idx_shapes, 7);
+  set_vector_bytes(compute_encoder, idx_strides, 8);
+  set_vector_bytes(compute_encoder, idx_contigs, 9);
+  compute_encoder->setBytes(&idx_ndim, sizeof(int), 10);
 
   // Set index buffers
   for (int i = 0; i < nidx; ++i) {
@@ -172,12 +171,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 
   // Copy src into out
-  auto copy_type =
-      inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General;
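+  // Choose the cheapest copy for the input's layout: broadcast a single
+  // scalar, do a flat vector copy when row-contiguous, else fall back to a
+  // general strided copy.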
+  CopyType copy_type;
+  if (inputs[0].data_size() == 1) {
+    copy_type = CopyType::Scalar;
+  } else if (inputs[0].flags().row_contiguous) {
+    copy_type = CopyType::Vector;
+  } else {
+    copy_type = CopyType::General;
+  }
   copy_gpu(inputs[0], out, copy_type);
 
+  auto& upd = inputs.back();
+
   // Empty update
-  if (inputs.back().size() == 0) {
+  if (upd.size() == 0) {
     return;
   }
 
@@ -186,19 +193,20 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& d = metal::device(s.device);
 
   int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  bool index_nd1_specialization = (idx_ndim == 1);
-
-  // Bail from fast path (1d index specialization) if scatter dims aren't
-  // the outermost dims and contiguous since update access won't be raster
-  // order.
-  for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
-    index_nd1_specialization &= (axes_[i] == i);
-  }
-
-  // Bail from fast path (1d index specialization) if any of the dims are
-  // broadcasted, since we can't rely on linear indexing in that case.
-  for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
-    index_nd1_specialization &= inputs[i].flags().row_contiguous;
+  size_t idx_size = nidx ? inputs[1].size() : 1;
+
+  auto idx_to_out = idx_size / out.size();
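+  // Scale the per-thread work (nwork) with the ratio of index elements to
+  // output elements: the more updates land per output, the more it pays to
+  // batch several of them in one thread.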
+  int nwork;
+  if (idx_ndim <= 1 || idx_to_out < 1) {
+    nwork = 1;
+  } else if (idx_to_out <= 4) {
+    nwork = 4;
+  } else if (idx_to_out < 16) {
+    nwork = 8;
+  } else if (idx_to_out < 32) {
+    nwork = 16;
+  } else {
+    nwork = 32;
   }
 
   std::string lib_name;
@@ -222,19 +230,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
       op_name = "min";
       break;
   }
-
+  auto upd_contig = upd.flags().row_contiguous;
   {
     std::ostringstream kname;
-    if (index_nd1_specialization) {
-      kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
-    } else {
-      kname << "scatter" << type_to_name(out) << idx_type_name;
-    }
-    kname << "_" << op_name << "_" << nidx;
+    kname << "scatter" << type_to_name(out) << idx_type_name;
+    kname << "_" << op_name << "_" << nidx << "_"
+          << (upd_contig ? "updc_true" : "updc_false") << "_nwork" << nwork;
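+    // Baking upd_contig and nwork into the name gives each variant its own
+    // cached library and kernel specialization.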
     lib_name = kname.str();
     kernel_name = kname.str();
   }
-
   auto lib = d.get_library(lib_name, [&]() {
     std::ostringstream kernel_source;
     kernel_source << metal::utils() << metal::reduce_utils()
@@ -274,14 +278,15 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
         op_type,
         nidx,
         idx_args,
-        idx_arr);
+        idx_arr,
+        upd_contig,
+        nwork);
     return kernel_source.str();
   });
 
   auto& compute_encoder = d.get_command_encoder(s.index);
   auto kernel = d.get_kernel(kernel_name, lib);
 
-  auto& upd = inputs.back();
   size_t nthreads = upd.size();
 
   compute_encoder->setComputePipelineState(kernel);
@@ -291,109 +296,86 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.set_output_array(out, 2);
 
   // Set update info
-  uint upd_ndim = upd.ndim();
+  size_t upd_ndim = upd.ndim();
   size_t upd_size = 1;
   for (int i = idx_ndim; i < upd.ndim(); ++i) {
     upd_size *= upd.shape(i);
   }
-  if (index_nd1_specialization) {
-    compute_encoder->setBytes(
-        out.shape().data(), out.shape().size() * sizeof(int), 3);
-    compute_encoder->setBytes(
-        out.strides().data(), out.strides().size() * sizeof(size_t), 4);
-
-    size_t out_ndim = out.ndim();
-    compute_encoder->setBytes(&out_ndim, sizeof(out_ndim), 5);
-    if (upd_ndim <= 1) {
-      // Placeholder so Metal doesn't complain
-      int shape_ = 0;
-      compute_encoder->setBytes(&shape_, sizeof(int), 6);
-    } else {
-      compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 6);
-    }
-    compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 7);
-    compute_encoder->setBytes(&upd_size, sizeof(size_t), 8);
-
-    // Set index buffers
-    for (int i = 0; i < nidx; ++i) {
-      compute_encoder.set_input_array(inputs[i + 1], 20 + i);
-    }
-
-    // Launch grid
-    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
-    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+  // Collect all idx shapes and strides into one place
+  std::vector<int> idx_shapes;
+  std::vector<size_t> idx_strides;
+  // To access .data() use char instead of bool
+  // bool is 1 byte in Metal so this is safe
+  std::vector<char> idx_contigs;
+  for (int i = 0; i < nidx; ++i) {
+    idx_shapes.insert(
+        idx_shapes.end(),
+        inputs[i + 1].shape().begin(),
+        inputs[i + 1].shape().end());
+    idx_strides.insert(
+        idx_strides.end(),
+        inputs[i + 1].strides().begin(),
+        inputs[i + 1].strides().end());
+    idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
+  }
 
+  if (upd_ndim == 0) {
+    // Need placeholders so Metal doesn't complain
+    int shape_ = 0;
+    size_t stride_ = 0;
+    compute_encoder->setBytes(&shape_, sizeof(int), 3);
+    compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
   } else {
-    // Collect all idx shapes and strides into one place
-    std::vector<int> idx_shapes;
-    std::vector<size_t> idx_strides;
-
-    for (int i = 0; i < nidx; ++i) {
-      idx_shapes.insert(
-          idx_shapes.end(),
-          inputs[i + 1].shape().begin(),
-          inputs[i + 1].shape().end());
-
-      idx_strides.insert(
-          idx_strides.end(),
-          inputs[i + 1].strides().begin(),
-          inputs[i + 1].strides().end());
-    }
+    set_vector_bytes(compute_encoder, upd.shape(), 3);
+    set_vector_bytes(compute_encoder, upd.strides(), 4);
+  }
+  compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
+  compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
+
+  // Set output info
+  size_t out_ndim = out.ndim();
+  if (out_ndim == 0) {
+    // Need placeholders so Metal doesn't complain
+    int shape_ = 0;
+    size_t stride_ = 0;
+    compute_encoder->setBytes(&shape_, sizeof(int), 7);
+    compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
+  } else {
+    set_vector_bytes(compute_encoder, out.shape(), 7);
+    set_vector_bytes(compute_encoder, out.strides(), 8);
+  }
+  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
+  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
 
-    if (upd_ndim == 0) {
-      // Need placeholders so Metal doesn't complain
-      int shape_ = 0;
-      size_t stride_ = 0;
-      compute_encoder->setBytes(&shape_, sizeof(int), 3);
-      compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
-    } else {
-      compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
-      compute_encoder->setBytes(
-          upd.strides().data(), upd_ndim * sizeof(size_t), 4);
-    }
-    compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
-    compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
-
-    // Set output info
-    size_t out_ndim = out.ndim();
-    if (out_ndim == 0) {
-      // Need placeholders so Metal doesn't complain
-      int shape_ = 0;
-      size_t stride_ = 0;
-      compute_encoder->setBytes(&shape_, sizeof(int), 7);
-      compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
-    } else {
-      compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
-      compute_encoder->setBytes(
-          out.strides().data(), out_ndim * sizeof(size_t), 8);
-    }
-    compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
-    compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
-
-    // Set index info
-    if (idx_ndim == 0) {
-      // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
-      // error in the metal API.
-      idx_shapes.push_back(0);
-      idx_strides.push_back(0);
-    }
-    compute_encoder->setBytes(
-        idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
-    compute_encoder->setBytes(
-        idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
-    compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
-
-    // Set index buffers
-    for (int i = 0; i < nidx; ++i) {
-      compute_encoder.set_input_array(inputs[i + 1], 20 + i);
-    }
+  // Set index info
+  if (idx_ndim == 0) {
+    // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
+    // error in the metal API.
+    idx_shapes.push_back(0);
+    idx_strides.push_back(0);
+    idx_contigs.push_back(false);
+  }
+  set_vector_bytes(compute_encoder, idx_shapes, 11);
+  set_vector_bytes(compute_encoder, idx_strides, 12);
+  set_vector_bytes(compute_encoder, idx_contigs, 13);
+  compute_encoder->setBytes(&idx_ndim, sizeof(int), 14);
+  compute_encoder->setBytes(&idx_size, sizeof(size_t), 15);
+
+  // Set index buffers
+  for (int i = 0; i < nidx; ++i) {
+    compute_encoder.set_input_array(inputs[i + 1], 20 + i);
+  }
 
-    // Launch grid
-    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
-    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-    compute_encoder.dispatchThreads(grid_dims, group_dims);
+  // Launch grid
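+  // Each thread handles nwork updates along y, so divide the grid's y
+  // extent by nwork, rounding up to cover the remainder.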
+  auto grid_y = (nthreads / upd_size);
+  grid_y = (grid_y + nwork - 1) / nwork;
+  MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
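+  // Sanity check: the generated kernel is expected to report a full
+  // 1024-thread threadgroup.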
+  auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
+  if (thread_group_size != 1024) {
+    throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
   }
+  MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
+  compute_encoder.dispatchThreads(grid_dims, group_dims);
 }
 
 } // namespace mlx::core
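
Note: set_vector_bytes is not defined in this diff. A minimal sketch of what such a helper plausibly looks like (the template form and encoder type are assumptions, not the verbatim MLX definition):

    #include <vector>

    template <typename T, typename EncoderT>
    void set_vector_bytes(EncoderT& enc, const std::vector<T>& vec, int index) {
      // Upload the vector's contents as a small constant buffer bound at
      // `index`, mirroring the raw setBytes(data, size * sizeof(T), index)
      // calls that this diff replaces.
      enc->setBytes(vec.data(), vec.size() * sizeof(T), index);
    }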