Skip to content

Commit 84cd148

Browse files
committed
add build script for metal kernels (from mlx)
Signed-off-by: Alex Chi <[email protected]>
1 parent 46427f7 commit 84cd148

25 files changed

+1137
-11
lines changed

book/src/week1-overview.md

+4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ We will use the Qwen2-7B-Instruct model for this week. As we need to dequantize
88
20GB of memory in week 1. If you do not have enough memory, you can consider using the smaller 0.5B model (we do not have
99
infra to test it so you need to figure out things on your own unfortunately).
1010

11+
The MLX version of the Qwen2-7B-Instruct model we downloaded in the setup is an int4 quantized version of the original bfloat16 model.
12+
1113
## What We will Cover
1214

1315
* Attention, Multi-Head Attention, and Grouped/Multi Query Attention
@@ -44,5 +46,7 @@ utilize these resources to better understand the internals of the model and what
4446
- [Huggingface Transformers - Qwen2](https://github.com/huggingface/transformers/tree/main/src/transformers/models/qwen2)
4547
- [vLLM Qwen2](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2.py)
4648
- [mlx-lm Qwen2](https://github.com/ml-explore/mlx-lm/blob/main/mlx_lm/models/qwen2.py)
49+
- [Qwen2 Technical Report](https://arxiv.org/pdf/2407.10671)
50+
- [Qwen2.5 Technical Report](https://arxiv.org/pdf/2412.15115)
4751

4852
{{#include copyright.md}}

book/src/week2-overview.md

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
https://github.com/ml-explore/mlx/blob/main/mlx/backend/cpu/quantized.cpp
2+
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/linear.py
3+
MLX uses INT4 W4A16
4+
https://ml-explore.github.io/mlx/build/html/dev/extensions.html

build-extension.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from __future__ import annotations
2+
3+
from pathlib import Path
4+
from setuptools import Distribution
5+
from mlx import extension
6+
import shutil
7+
8+
9+
def build():
10+
src_dir = Path(__file__).parent.joinpath("src").joinpath("extensions_ref")
11+
ext_modules = [extension.CMakeExtension("tiny_llm_ext_ref._ext", src_dir)]
12+
distribution = Distribution(
13+
{
14+
"name": "tiny_llm_ext_ref",
15+
"ext_modules": ext_modules,
16+
}
17+
)
18+
cmd = extension.CMakeBuild(distribution)
19+
cmd.ensure_finalized()
20+
cmd.run()
21+
for output in cmd.get_outputs():
22+
output = Path(output)
23+
relative_extension = src_dir / output.relative_to(cmd.build_lib)
24+
shutil.copyfile(output, relative_extension)
25+
26+
27+
if __name__ == "__main__":
28+
build()

pyproject.toml

+10-1
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,18 @@ numpy = "^2.2.4"
2020
ruff = "^0.11.6"
2121

2222
[build-system]
23-
requires = ["poetry-core"]
23+
requires = [
24+
"poetry-core",
25+
"setuptools>=42",
26+
"cmake>=3.25",
27+
"mlx>=0.18.0",
28+
"nanobind==2.4.0"
29+
]
2430
build-backend = "poetry.core.masonry.api"
2531

32+
[tool.poetry.build]
33+
script = "build-extension.py"
34+
2635
[project]
2736
name = "tiny-llm"
2837
version = "0.1.0"

src/extensions_ref/CMakeLists.txt

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
cmake_minimum_required(VERSION 3.27)
2+
3+
project(_ext LANGUAGES CXX)
4+
5+
# ----------------------------- Setup -----------------------------
6+
set(CMAKE_CXX_STANDARD 17)
7+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
8+
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
9+
10+
option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON)
11+
12+
# ----------------------------- Dependencies -----------------------------
13+
find_package(
14+
Python 3.8
15+
COMPONENTS Interpreter Development.Module
16+
REQUIRED)
17+
execute_process(
18+
COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
19+
OUTPUT_STRIP_TRAILING_WHITESPACE
20+
OUTPUT_VARIABLE nanobind_ROOT)
21+
find_package(nanobind CONFIG REQUIRED)
22+
23+
execute_process(
24+
COMMAND "${Python_EXECUTABLE}" -m mlx --cmake-dir
25+
OUTPUT_STRIP_TRAILING_WHITESPACE
26+
OUTPUT_VARIABLE MLX_ROOT)
27+
find_package(MLX CONFIG REQUIRED)
28+
29+
# ----------------------------- Extensions -----------------------------
30+
31+
# Add library
32+
add_library(tiny_llm_ext_ref)
33+
34+
# Add sources
35+
target_sources(tiny_llm_ext_ref PUBLIC ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp)
36+
37+
# Add include headers
38+
target_include_directories(tiny_llm_ext_ref PUBLIC ${CMAKE_CURRENT_LIST_DIR})
39+
40+
# Link to mlx
41+
target_link_libraries(tiny_llm_ext_ref PUBLIC mlx)
42+
43+
44+
# ----------------------------- Metal -----------------------------
45+
46+
# Build metallib
47+
if(MLX_BUILD_METAL)
48+
mlx_build_metallib(
49+
TARGET
50+
tiny_llm_ext_ref_metallib
51+
TITLE
52+
tiny_llm_ext_ref
53+
SOURCES
54+
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
55+
INCLUDE_DIRS
56+
${PROJECT_SOURCE_DIR}
57+
${MLX_INCLUDE_DIRS}
58+
OUTPUT_DIRECTORY
59+
${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
60+
61+
add_dependencies(tiny_llm_ext_ref tiny_llm_ext_ref_metallib)
62+
endif()
63+
64+
# ----------------------------- Python Bindings -----------------------------
65+
nanobind_add_module(
66+
_ext
67+
NB_STATIC
68+
STABLE_ABI
69+
LTO
70+
NOMINSIZE
71+
NB_DOMAIN
72+
mlx
73+
${CMAKE_CURRENT_LIST_DIR}/bindings.cpp)
74+
target_link_libraries(_ext PRIVATE tiny_llm_ext_ref)
75+
76+
if(BUILD_SHARED_LIBS)
77+
target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
78+
endif()

0 commit comments

Comments
 (0)