Skip to content

Commit b39ebc9

Browse files
committed
add def of flash_attention
Signed-off-by: Alex Chi Z <[email protected]>
1 parent a20bb45 commit b39ebc9

12 files changed

+46
-413
lines changed

book/src/week2-overview.md

+4
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ speculative decoding
1111
prefill and decode separation
1212
quantized kv cache
1313
Assert return data type
14+
15+
https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/sdpa_vector.h
16+
https://github.com/philipturner/metal-flash-attention
17+
https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h

build_ext.sh

-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,3 @@
33
set -e
44
pdm run build-ext-ref
55
cp src/extensions_ref/build/lib/tiny_llm_ext_ref/tiny_llm_ext_ref.metallib .venv/lib/python3.12/site-packages/mlx/lib/
6-
pdm run test-week2-ref -k 'week_2_day_2'

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ test.cmd = "pytest tests"
2929
test-week1-ref.cmd = "pytest tests_ref_impl_week1"
3030
test-week2-ref.cmd = "pytest tests_ref_impl_week2"
3131
format = "ruff format"
32+
format-cpp.shell = "find src/extensions_ref -type file \\( -name '*.h' -or -name '*.cpp' \\) | xargs -n1 clang-format -i"
3233

3334
[tool.pytest.ini_options]
3435
addopts = [

src/extensions_ref/CMakeLists.txt

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ add_library(tiny_llm_ext_ref)
3535
target_sources(
3636
tiny_llm_ext_ref
3737
PUBLIC
38-
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
3938
${CMAKE_CURRENT_LIST_DIR}/src/quantized_matmul.cpp
39+
${CMAKE_CURRENT_LIST_DIR}/src/flash_attention.cpp
4040
)
4141

4242
# Add include headers
@@ -57,8 +57,8 @@ if(MLX_BUILD_METAL)
5757
TITLE
5858
tiny_llm_ext_ref
5959
SOURCES
60-
${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
6160
${CMAKE_CURRENT_LIST_DIR}/src/quantized_matmul.metal
61+
${CMAKE_CURRENT_LIST_DIR}/src/flash_attention.metal
6262
INCLUDE_DIRS
6363
${PROJECT_SOURCE_DIR}
6464
${MLX_INCLUDE_DIRS}

src/extensions_ref/axpby/axpby.cpp

-260
This file was deleted.

src/extensions_ref/axpby/axpby.h

-76
This file was deleted.

0 commit comments

Comments
 (0)