Skip to content

Commit 8b76571

Browse files
authored
Fix extensions (#1126)
* fix extensions * title * enable circle * fix nanobind tag * fix bug in doc * try to fix config * typo
1 parent e78a651 commit 8b76571

File tree

7 files changed

+36
-26
lines changed

7 files changed

+36
-26
lines changed

.circleci/config.yml

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,6 @@ jobs:
4949
name: Run Python tests
5050
command: |
5151
python3 -m unittest discover python/tests -v
52-
# TODO: Reenable when extension api becomes stable
53-
# - run:
54-
# name: Build example extension
55-
# command: |
56-
# cd examples/extensions && python3 -m pip install .
5752
- run:
5853
name: Build CPP only
5954
command: |
@@ -101,11 +96,10 @@ jobs:
10196
source env/bin/activate
10297
LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
10398
LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
104-
# TODO: Reenable when extension api becomes stable
105-
# - run:
106-
# name: Build example extension
107-
# command: |
108-
# cd examples/extensions && python3.11 -m pip install .
99+
- run:
100+
name: Build example extension
101+
command: |
102+
cd examples/extensions && python3.8 -m pip install .
109103
- store_test_results:
110104
path: test-results
111105
- run:

docs/src/dev/extensions.rst

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
Developer Documentation
2-
=======================
1+
Custom Extensions in MLX
2+
========================
33

44
You can extend MLX with custom operations on the CPU or GPU. This guide
55
explains how to do that with a simple example.
@@ -494,7 +494,7 @@ below.
494494
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
495495

496496
// Prepare to encode kernel
497-
auto compute_encoder = d.get_command_encoder(s.index);
497+
auto& compute_encoder = d.get_command_encoder(s.index);
498498
compute_encoder->setComputePipelineState(kernel);
499499

500500
// Kernel parameters are registered with buffer indices corresponding to
@@ -503,11 +503,11 @@ below.
503503
size_t nelem = out.size();
504504

505505
// Encode input arrays to kernel
506-
set_array_buffer(compute_encoder, x, 0);
507-
set_array_buffer(compute_encoder, y, 1);
506+
compute_encoder.set_input_array(x, 0);
507+
compute_encoder.set_input_array(y, 1);
508508

509509
// Encode output arrays to kernel
510-
set_array_buffer(compute_encoder, out, 2);
510+
compute_encoder.set_output_array(out, 2);
511511

512512
// Encode alpha and beta
513513
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
@@ -531,7 +531,7 @@ below.
531531

532532
// Launch the grid with the given number of threads divided among
533533
// the given threadgroups
534-
compute_encoder->dispatchThreads(grid_dims, group_dims);
534+
compute_encoder.dispatchThreads(grid_dims, group_dims);
535535
}
536536

537537
We can now call the :meth:`axpby` operation on both the CPU and the GPU!
@@ -825,7 +825,7 @@ Let's look at a simple script and its results:
825825
826826
print(f"c shape: {c.shape}")
827827
print(f"c dtype: {c.dtype}")
828-
print(f"c correctness: {mx.all(c == 6.0).item()}")
828+
print(f"c correct: {mx.all(c == 6.0).item()}")
829829
830830
Output:
831831

examples/extensions/README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
## Build the extensions
2+
## Build
33

44
```
55
pip install -e .
@@ -16,3 +16,9 @@ And then run:
1616
```
1717
python setup.py build_ext -j8 --inplace
1818
```
19+
20+
## Test
21+
22+
```
23+
python test.py
24+
`

examples/extensions/axpby/axpby.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ void Axpby::eval_gpu(
257257
auto kernel = d.get_kernel(kname.str(), "mlx_ext");
258258

259259
// Prepare to encode kernel
260-
auto compute_encoder = d.get_command_encoder(s.index);
260+
auto& compute_encoder = d.get_command_encoder(s.index);
261261
compute_encoder->setComputePipelineState(kernel);
262262

263263
// Kernel parameters are registered with buffer indices corresponding to
@@ -266,11 +266,11 @@ void Axpby::eval_gpu(
266266
size_t nelem = out.size();
267267

268268
// Encode input arrays to kernel
269-
set_array_buffer(compute_encoder, x, 0);
270-
set_array_buffer(compute_encoder, y, 1);
269+
compute_encoder.set_input_array(x, 0);
270+
compute_encoder.set_input_array(y, 1);
271271

272272
// Encode output arrays to kernel
273-
set_array_buffer(compute_encoder, out, 2);
273+
compute_encoder.set_output_array(out, 2);
274274

275275
// Encode alpha and beta
276276
compute_encoder->setBytes(&alpha_, sizeof(float), 3);
@@ -296,7 +296,7 @@ void Axpby::eval_gpu(
296296

297297
// Launch the grid with the given number of threads divided among
298298
// the given threadgroups
299-
compute_encoder->dispatchThreads(grid_dims, group_dims);
299+
compute_encoder.dispatchThreads(grid_dims, group_dims);
300300
}
301301

302302
#else // Metal is not available

examples/extensions/mlx_sample_extensions/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22

33
import mlx.core as mx
44

5-
from .mlx_sample_extensions import *
5+
from ._ext import axpby
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
setuptools>=42
22
cmake>=3.24
33
mlx>=0.9.0
4-
nanobind@git+https://github.com/wjakob/nanobind.git#egg=4148debcf91f5ccab0c3b8d67b5c3cabd61f407f
4+
nanobind@git+https://github.com/wjakob/nanobind.git@2f04eac452a6d9142dedb957701bdb20125561e4

examples/extensions/test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import mlx.core as mx
2+
from mlx_sample_extensions import axpby
3+
4+
a = mx.ones((3, 4))
5+
b = mx.ones((3, 4))
6+
c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
7+
8+
print(f"c shape: {c.shape}")
9+
print(f"c dtype: {c.dtype}")
10+
print(f"c correct: {mx.all(c == 6.0).item()}")

0 commit comments

Comments
 (0)