1- Developer Documentation
2- =======================
1+ Custom Extensions in MLX
2+ ========================
33
44You can extend MLX with custom operations on the CPU or GPU. This guide
55explains how to do that with a simple example.
@@ -494,7 +494,7 @@ below.
494494 auto kernel = d.get_kernel(kname.str(), "mlx_ext");
495495
496496 // Prepare to encode kernel
497- auto compute_encoder = d.get_command_encoder(s.index);
497+ auto& compute_encoder = d.get_command_encoder(s.index);
498498 compute_encoder->setComputePipelineState(kernel);
499499
500500 // Kernel parameters are registered with buffer indices corresponding to
@@ -503,11 +503,11 @@ below.
503503 size_t nelem = out.size();
504504
505505 // Encode input arrays to kernel
506- set_array_buffer( compute_encoder, x, 0);
507- set_array_buffer( compute_encoder, y, 1);
506+ compute_encoder.set_input_array( x, 0);
507+ compute_encoder.set_input_array( y, 1);
508508
509509 // Encode output arrays to kernel
510- set_array_buffer( compute_encoder, out, 2);
510+ compute_encoder.set_output_array( out, 2);
511511
512512 // Encode alpha and beta
513513 compute_encoder->setBytes(&alpha_, sizeof(float), 3);
@@ -531,7 +531,7 @@ below.
531531
532532 // Launch the grid with the given number of threads divided among
533533 // the given threadgroups
534- compute_encoder-> dispatchThreads(grid_dims, group_dims);
534+ compute_encoder. dispatchThreads(grid_dims, group_dims);
535535 }
536536
537537We can now call the :meth: `axpby ` operation on both the CPU and the GPU!
@@ -825,7 +825,7 @@ Let's look at a simple script and its results:
825825
826826 print (f " c shape: { c.shape} " )
827827 print (f " c dtype: { c.dtype} " )
828- print (f " c correctness : { mx.all(c == 6.0 ).item()} " )
828+ print (f " c correct : { mx.all(c == 6.0 ).item()} " )
829829
830830 Output:
831831
0 commit comments