@@ -142,7 +142,28 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
   // Get kernel name
   std::ostringstream kname;
   std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
-  kname << "scatter" << type_to_name(out) << idx_type_name;
+
+  int idx_ndim = nidx ? inputs[1].ndim() : 0;
+  bool index_nd1_specialization = (idx_ndim == 1);
+
+  // Bail from fast path (1d index specialization) if scatter dims aren't
+  // the outermost dims and contiguous since update access won't be raster
+  // order.
+  for (auto i = 0; i < axes_.size() && index_nd1_specialization; i++) {
+    index_nd1_specialization &= (axes_[i] == i);
+  }
+
+  // Bail from fast path (1d index specialization) if any of the dims are
+  // broadcasted, since we can't rely on linear indexing in that case.
+  for (int i = 1; i < inputs.size() && index_nd1_specialization; i++) {
+    index_nd1_specialization &= inputs[i].flags().row_contiguous;
+  }
+
+  if (index_nd1_specialization) {
+    kname << "scatter_1d_index" << type_to_name(out) << idx_type_name;
+  } else {
+    kname << "scatter" << type_to_name(out) << idx_type_name;
+  }
   switch (reduce_type_) {
     case Scatter::None:
       kname << "_none";
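
The two loops above are the gate for the new fast path: every index array must be 1-d, the scatter axes must be exactly the leading dimensions (so updates land in raster order), and no index or update input may be broadcast. Restated as a standalone predicate, the check looks like the sketch below; the function name, its parameters, and the vector of flags standing in for inputs[i].flags().row_contiguous are hypothetical, for illustration only.

#include <cstddef>
#include <vector>

// Sketch: mirrors the gating logic in the hunk above. The fast path needs
// (a) 1-d index arrays, (b) scatter axes equal to the leading dims
// (0, 1, 2, ...), and (c) every index/update input row-contiguous so
// linear indexing is valid.
bool use_1d_index_fast_path(
    int idx_ndim,
    const std::vector<int>& axes,
    const std::vector<bool>& inputs_row_contiguous) {
  bool ok = (idx_ndim == 1);
  for (size_t i = 0; i < axes.size() && ok; i++) {
    ok &= (axes[i] == static_cast<int>(i));
  }
  for (size_t i = 0; i < inputs_row_contiguous.size() && ok; i++) {
    ok &= inputs_row_contiguous[i];
  }
  return ok;
}
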
@@ -170,85 +191,106 @@ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {

   compute_encoder->setComputePipelineState(kernel);

-  // Collect all idx shapes and strides into one place
-  int idx_ndim = nidx ? inputs[1].ndim() : 0;
-  std::vector<int> idx_shapes;
-  std::vector<size_t> idx_strides;
-
-  for (int i = 0; i < nidx; ++i) {
-    idx_shapes.insert(
-        idx_shapes.end(),
-        inputs[i + 1].shape().begin(),
-        inputs[i + 1].shape().end());
-
-    idx_strides.insert(
-        idx_strides.end(),
-        inputs[i + 1].strides().begin(),
-        inputs[i + 1].strides().end());
-  }
-
   // Set all the buffers
   set_array_buffer(compute_encoder, upd, 1);
   set_array_buffer(compute_encoder, out, 2);

   // Set update info
-  size_t upd_ndim = upd.ndim();
+  uint upd_ndim = upd.ndim();
   size_t upd_size = 1;
   for (int i = idx_ndim; i < upd.ndim(); ++i) {
     upd_size *= upd.shape(i);
   }
-  if (upd_ndim == 0) {
-    // Need placeholders so Metal doesn't complain
-    int shape_ = 0;
-    size_t stride_ = 0;
-    compute_encoder->setBytes(&shape_, sizeof(int), 3);
-    compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
-  } else {
-    compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
+
+  if (index_nd1_specialization) {
+    bool upd_col_contiguous = upd.flags().col_contiguous;
     compute_encoder->setBytes(
-        upd.strides().data(), upd_ndim * sizeof(size_t), 4);
-  }
-  compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
-  compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
-
-  // Set output info
-  size_t out_ndim = out.ndim();
-  if (out_ndim == 0) {
-    // Need placeholders so Metal doesn't complain
-    int shape_ = 0;
-    size_t stride_ = 0;
-    compute_encoder->setBytes(&shape_, sizeof(int), 7);
-    compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
-  } else {
-    compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
+        out.shape().data(), out.shape().size() * sizeof(int), 3);
     compute_encoder->setBytes(
-        out.strides().data(), out_ndim * sizeof(size_t), 8);
-  }
-  compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
-  compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
+        out.strides().data(), out.strides().size() * sizeof(size_t), 4);
+    compute_encoder->setBytes(&upd_size, sizeof(size_t), 5);
+    compute_encoder->setBytes(&upd_col_contiguous, sizeof(bool), 6);

-  // Set index info
-  if (idx_ndim == 0) {
-    // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
-    // error in the metal API.
-    idx_shapes.push_back(0);
-    idx_strides.push_back(0);
-  }
-  compute_encoder->setBytes(
-      idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
-  compute_encoder->setBytes(
-      idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
-  compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
+    // Set index buffers
+    for (int i = 1; i < nidx + 1; ++i) {
+      set_array_buffer(compute_encoder, inputs[i], 20 + i);
+    }

-  // Set index buffers
-  for (int i = 1; i < nidx + 1; ++i) {
-    set_array_buffer(compute_encoder, inputs[i], 20 + i);
-  }
+    // Launch grid
+    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
+    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);

-  // Launch grid
-  MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
-  MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
-  compute_encoder->dispatchThreads(grid_dims, group_dims);
+  } else {
+    // Collect all idx shapes and strides into one place
+    std::vector<int> idx_shapes;
+    std::vector<size_t> idx_strides;
+
+    for (int i = 0; i < nidx; ++i) {
+      idx_shapes.insert(
+          idx_shapes.end(),
+          inputs[i + 1].shape().begin(),
+          inputs[i + 1].shape().end());
+
+      idx_strides.insert(
+          idx_strides.end(),
+          inputs[i + 1].strides().begin(),
+          inputs[i + 1].strides().end());
+    }
+
+    if (upd_ndim == 0) {
+      // Need placeholders so Metal doesn't complain
+      int shape_ = 0;
+      size_t stride_ = 0;
+      compute_encoder->setBytes(&shape_, sizeof(int), 3);
+      compute_encoder->setBytes(&stride_, sizeof(size_t), 4);
+    } else {
+      compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3);
+      compute_encoder->setBytes(
+          upd.strides().data(), upd_ndim * sizeof(size_t), 4);
+    }
+    compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5);
+    compute_encoder->setBytes(&upd_size, sizeof(size_t), 6);
+
+    // Set output info
+    size_t out_ndim = out.ndim();
+    if (out_ndim == 0) {
+      // Need placeholders so Metal doesn't complain
+      int shape_ = 0;
+      size_t stride_ = 0;
+      compute_encoder->setBytes(&shape_, sizeof(int), 7);
+      compute_encoder->setBytes(&stride_, sizeof(size_t), 8);
+    } else {
+      compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7);
+      compute_encoder->setBytes(
+          out.strides().data(), out_ndim * sizeof(size_t), 8);
+    }
+    compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9);
+    compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10);
+
+    // Set index info
+    if (idx_ndim == 0) {
+      // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
+      // error in the metal API.
+      idx_shapes.push_back(0);
+      idx_strides.push_back(0);
+    }
+    compute_encoder->setBytes(
+        idx_shapes.data(), idx_shapes.size() * sizeof(int), 11);
+    compute_encoder->setBytes(
+        idx_strides.data(), idx_strides.size() * sizeof(size_t), 12);
+    compute_encoder->setBytes(&idx_ndim, sizeof(int), 13);
+
+    // Set index buffers
+    for (int i = 1; i < nidx + 1; ++i) {
+      set_array_buffer(compute_encoder, inputs[i], 20 + i);
+    }
+
+    // Launch grid
+    MTL::Size grid_dims = MTL::Size(upd_size, nthreads / upd_size, 1);
+    MTL::Size group_dims = get_block_dims(upd_size, nthreads / upd_size, 1);
+    compute_encoder->dispatchThreads(grid_dims, group_dims);
+  }
 }

 } // namespace mlx::core
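
On the device side, the fast path's bindings (updates at buffer 1, output at 2, output shape and strides at 3 and 4, upd_size at 5, the col-contiguous flag at 6, and index arrays at 20 + i) imply a kernel shaped roughly like the sketch below. This is an inference from the host code only, assuming a single index array, a None reduction, and row-contiguous updates; the real scatter_1d_index kernels are defined in the Metal sources, which this diff does not show.

// Hedged sketch: a plausible shape for the fast-path kernel, inferred
// solely from the host-side bindings above. Name, template structure,
// and indexing are assumptions, not the actual MLX kernel.
template <typename T, typename IdxT>
[[kernel]] void scatter_1d_index_none(
    const device T* updates [[buffer(1)]],
    device T* out [[buffer(2)]],
    const constant int* out_shape [[buffer(3)]],      // unused in this sketch
    const constant size_t* out_strides [[buffer(4)]],
    const constant size_t& upd_size [[buffer(5)]],
    const constant bool& upd_col_contiguous [[buffer(6)]],
    const device IdxT* indices [[buffer(21)]],        // 20 + i with i == 1
    uint2 gid [[thread_position_in_grid]]) {
  // Grid is (upd_size, nthreads / upd_size, 1): y selects which index
  // (and update row) is being scattered, x walks one update row. With the
  // scatter axis as dim 0 and row-contiguous updates, both sides reduce
  // to linear offsets.
  size_t out_offset = indices[gid.y] * out_strides[0] + gid.x;
  out[out_offset] = updates[gid.y * upd_size + gid.x];
}
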
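In the general (else) path, the shapes and strides of all index arrays are flattened into two buffers (bindings 11 and 12), and the kernel recovers each array's entries using idx_ndim (binding 13): dim k of index array i sits at position i * idx_ndim + k. A worked example of that layout, with hypothetical shapes:

#include <cassert>
#include <vector>

int main() {
  // Hypothetical example: two broadcast-compatible index arrays with
  // shapes (4, 1) and (1, 3), so idx_ndim == 2 for both.
  std::vector<std::vector<int>> shapes = {{4, 1}, {1, 3}};
  std::vector<int> idx_shapes;
  for (const auto& s : shapes) {
    idx_shapes.insert(idx_shapes.end(), s.begin(), s.end());
  }
  int idx_ndim = 2;
  // The kernel sees one flat buffer: {4, 1, 1, 3}.
  assert(idx_shapes == (std::vector<int>{4, 1, 1, 3}));
  // Dim 1 of index array 1 lives at 1 * idx_ndim + 1.
  assert(idx_shapes[1 * idx_ndim + 1] == 3);
  return 0;
}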