Skip to content

Commit f37041d

Browse files
Evandabest authored and meta-codesync[bot] committed
Extract Metal shaders to standalone .metal file with kernel wrapper class (#5167)
Summary:

## Summary

Before submitting PRs for IVFFlat, IVFPQ, IVFSQ, and BinaryFlat implementations, I wanted to run this by you guys to see if this is a good approach for organizing the Metal shader code. Currently the MSL source is embedded as an inline string in the Objective C++ file; this extracts it into a standalone `.metal` file with a wrapper class, which will be the foundation for all the upcoming index type PRs. This also upgrades the top-k kernel from a single-thread insertion sort to a 256-thread parallel bitonic sort for better GPU utilization.

- Extract inline MSL shader source from `MetalFlatKernels.mm` into a standalone `MetalDistance.metal` file for proper syntax highlighting, easier editing, and reviewability
- Add `MetalKernels` wrapper class that loads a precompiled `.metallib`, caches pipeline states, and provides typed `encode*()` dispatch methods (previously the library was recompiled from source on every search call)
- Precompile `.metal` → `.metallib` via `xcrun metal`/`xcrun metallib` in CMake, catching shader errors at build time instead of runtime
- Upgrade distance kernels from naive per-element to tiled GEMM-style (32x32 tiles with shared memory) and upgrade top-k to a parallel threadgroup-based selection kernel (256-thread parallel bitonic sort)
- Foundation for future index types (IVFFlat, IVFPQ, etc.) to add kernels to the same `.metal` file and methods to `MetalKernels`

## Changes

- **New:** `MetalDistance.metal` — `l2_squared_matrix`, `ip_matrix`, `topk_threadgroup_{32,64,128,256}`
- **New:** `MetalKernels.h/.mm` — runtime shader loading, pipeline caching, typed `encode*()` dispatch methods, per-device singleton
- **Modified:** `MetalFlatKernels.mm` — replaced inline MSL + manual dispatch with `MetalKernels` calls
- **Modified:** `CMakeLists.txt` — added `MetalKernels.mm` to build, added custom commands to precompile `.metal` → `.air` → `.metallib`

## Build and test

```bash
cmake -B build \
  -DFAISS_ENABLE_METAL=ON \
  -DFAISS_ENABLE_GPU=OFF \
  -DFAISS_ENABLE_PYTHON=OFF \
  -DBUILD_TESTING=ON \
  -DCMAKE_BUILD_TYPE=Release \
  -DCMAKE_PREFIX_PATH="$(brew --prefix libomp)" \
  .
cmake --build build --target faiss_metal TestMetalIndexFlat -j$(sysctl -n hw.logicalcpu)
cd build && ctest -R TestMetalIndexFlat --output-on-failure
```

Pull Request resolved: #5167

Reviewed By: alibeklfc

Differential Revision: D103979808

Pulled By: mnorris11

fbshipit-source-id: 9583d6767e589ad74f7aa58141cd2e9034e5698b
1 parent f323c0f commit f37041d

5 files changed

Lines changed: 486 additions & 163 deletions

File tree

faiss/gpu_metal/CMakeLists.txt

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
set(FAISS_METAL_SRC
1212
MetalResources.mm
1313
MetalIndex.mm
14+
MetalKernels.mm
1415
MetalFlatKernels.mm
1516
MetalIndexFlat.mm
1617
StandardMetalResources.mm
@@ -44,3 +45,30 @@ set_target_properties(faiss_metal PROPERTIES
4445
OBJCXX_STANDARD 17
4546
OBJCXX_STANDARD_REQUIRED ON
4647
)
48+
49+
# Pre-compile Metal shaders: .metal -> .air -> .metallib
#
# The shader source is compiled while building faiss_metal, so an MSL error
# stops the build instead of surfacing when the library is loaded.
# Both tools are invoked through xcrun against the macosx SDK.
find_program(XCRUN xcrun REQUIRED)

set(METAL_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/MetalDistance.metal)
set(METAL_AIR ${CMAKE_CURRENT_BINARY_DIR}/MetalDistance.air)
set(METAL_LIB ${CMAKE_CURRENT_BINARY_DIR}/MetalDistance.metallib)

# Step 1: compile the MSL source to an .air object.
# VERBATIM keeps argument escaping identical on every generator, so source or
# build directories containing spaces or shell metacharacters still work.
add_custom_command(
    OUTPUT ${METAL_AIR}
    COMMAND ${XCRUN} -sdk macosx metal -c ${METAL_SOURCE} -o ${METAL_AIR}
    DEPENDS ${METAL_SOURCE}
    COMMENT "Compiling MetalDistance.metal -> .air"
    VERBATIM
)

# Step 2: package the .air object into the .metallib read at runtime.
add_custom_command(
    OUTPUT ${METAL_LIB}
    COMMAND ${XCRUN} -sdk macosx metallib ${METAL_AIR} -o ${METAL_LIB}
    DEPENDS ${METAL_AIR}
    COMMENT "Linking MetalDistance.air -> .metallib"
    VERBATIM
)

# Tie the generated .metallib to the library target so that building
# faiss_metal always (re)builds the shaders first.
add_custom_target(metal_shaders DEPENDS ${METAL_LIB})
add_dependencies(faiss_metal metal_shaders)

# The absolute path of the generated .metallib is compiled into faiss_metal.
# NOTE(review): this is a build-tree path; an installed or relocated library
# presumably needs another way to locate the .metallib - confirm against the
# loader in MetalKernels.mm.
target_compile_definitions(faiss_metal PRIVATE
    FAISS_METALLIB_PATH="${METAL_LIB}"
)
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
// @lint-ignore-every LICENSELINT
2+
/**
3+
* Copyright (c) Meta Platforms, Inc. and its affiliates.
4+
*
5+
* This source code is licensed under the MIT license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*
8+
* Metal Shading Language kernels for distance computation and top-k selection.
9+
*
10+
* Kernel organization:
11+
* - Distance kernels: l2_squared_matrix, ip_matrix (tiled GEMM-style)
12+
* - Top-k selection: topk_threadgroup_K (parallel bitonic sort, K <= 256)
13+
*/
14+
15+
#include <metal_stdlib>
16+
using namespace metal;
17+
18+
/**
 * Squared-L2 distance matrix, computed one output tile per threadgroup.
 *
 * For every r < nq and c < nb this writes
 *   distances[r * nb + c] = sum over k < d of
 *                           (queries[r * d + k] - vectors[c * d + k])^2
 *
 * Buffers:
 *   queries   [[buffer(0)]]  read as nq rows of d floats
 *   vectors   [[buffer(1)]]  read as nb rows of d floats
 *   distances [[buffer(2)]]  written as nq rows of nb floats
 *   params    [[buffer(3)]]  {nq, nb, d}
 *
 * A threadgroup owns a TILE_M x TILE_N (32 x 32) block of the output and each
 * thread accumulates a 2 x 2 patch of that block. The d dimension is walked
 * in slices of TILE_K columns that are first staged in threadgroup memory.
 *
 * NOTE(review): the flat thread id (ty * 16 + tx) and TG_THREADS = 256 assume
 * the host dispatches 16 x 16 threadgroups; the dispatch code is not visible
 * here - confirm.
 *
 * Element offsets are computed as 64-bit `ulong`: the 32-bit products
 * row * d and row * nb wrap once nq * nb, nq * d or nb * d reaches 2^32,
 * which would silently read or write the wrong element.
 */
kernel void l2_squared_matrix(
    device const float* queries [[buffer(0)]],
    device const float* vectors [[buffer(1)]],
    device float* distances [[buffer(2)]],
    device const uint* params [[buffer(3)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 ltid [[thread_position_in_threadgroup]]
) {
    constexpr uint TILE_M = 32;
    constexpr uint TILE_N = 32;
    constexpr uint TILE_K = 16;
    constexpr uint TG_THREADS = 256;

    uint nq = params[0], nb = params[1], d = params[2];
    uint row0 = tgid.y * TILE_M;   // first query row of this tile
    uint col0 = tgid.x * TILE_N;   // first database vector of this tile
    uint ty = ltid.y, tx = ltid.x;
    uint tid = ty * 16 + tx;       // flat id used for the cooperative loads

    float acc00 = 0.0f, acc01 = 0.0f, acc10 = 0.0f, acc11 = 0.0f;

    threadgroup float tgQ[TILE_M * TILE_K];
    threadgroup float tgV[TILE_N * TILE_K];

    for (uint dk = 0; dk < d; dk += TILE_K) {
        uint kLen = min(TILE_K, d - dk);

        // Cooperative load of the query slice. Rows past nq and columns past
        // kLen are zero-filled so the inner loop needs no bounds checks.
        for (uint i = tid; i < TILE_M * TILE_K; i += TG_THREADS) {
            uint mr = i / TILE_K, mk = i % TILE_K;
            uint gRow = row0 + mr;
            tgQ[i] = (gRow < nq && mk < kLen)
                    ? queries[ulong(gRow) * d + dk + mk]
                    : 0.0f;
        }
        // Same for the database slice.
        for (uint i = tid; i < TILE_N * TILE_K; i += TG_THREADS) {
            uint mr = i / TILE_K, mk = i % TILE_K;
            uint gCol = col0 + mr;
            tgV[i] = (gCol < nb && mk < kLen)
                    ? vectors[ulong(gCol) * d + dk + mk]
                    : 0.0f;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Zero padding on both sides contributes (0 - 0)^2 = 0, so running
        // over the full TILE_K is safe even when kLen < TILE_K.
        for (uint kk = 0; kk < TILE_K; kk++) {
            float q0 = tgQ[(ty * 2) * TILE_K + kk];
            float q1 = tgQ[(ty * 2 + 1) * TILE_K + kk];
            float v0 = tgV[(tx * 2) * TILE_K + kk];
            float v1 = tgV[(tx * 2 + 1) * TILE_K + kk];
            float d00 = q0 - v0; acc00 += d00 * d00;
            float d01 = q0 - v1; acc01 += d01 * d01;
            float d10 = q1 - v0; acc10 += d10 * d10;
            float d11 = q1 - v1; acc11 += d11 * d11;
        }
        // Do not overwrite the staged slices while other threads still read them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    uint r0 = row0 + ty * 2, r1 = r0 + 1;
    uint c0 = col0 + tx * 2, c1 = c0 + 1;
    ulong out0 = ulong(r0) * nb;
    ulong out1 = ulong(r1) * nb;
    // Only store the entries that fall inside the nq x nb result.
    if (r0 < nq && c0 < nb) distances[out0 + c0] = acc00;
    if (r0 < nq && c1 < nb) distances[out0 + c1] = acc01;
    if (r1 < nq && c0 < nb) distances[out1 + c0] = acc10;
    if (r1 < nq && c1 < nb) distances[out1 + c1] = acc11;
}
77+
78+
/**
 * Inner-product matrix, computed one output tile per threadgroup.
 *
 * For every r < nq and c < nb this writes
 *   distances[r * nb + c] = sum over k < d of
 *                           queries[r * d + k] * vectors[c * d + k]
 *
 * Buffers:
 *   queries   [[buffer(0)]]  read as nq rows of d floats
 *   vectors   [[buffer(1)]]  read as nb rows of d floats
 *   distances [[buffer(2)]]  written as nq rows of nb floats
 *   params    [[buffer(3)]]  {nq, nb, d}
 *
 * A threadgroup owns a TILE_M x TILE_N (32 x 32) block of the output and each
 * thread accumulates a 2 x 2 patch of that block. The d dimension is walked
 * in slices of TILE_K columns that are first staged in threadgroup memory.
 *
 * NOTE(review): the flat thread id (ty * 16 + tx) and TG_THREADS = 256 assume
 * the host dispatches 16 x 16 threadgroups; the dispatch code is not visible
 * here - confirm.
 *
 * Element offsets are computed as 64-bit `ulong`: the 32-bit products
 * row * d and row * nb wrap once nq * nb, nq * d or nb * d reaches 2^32,
 * which would silently read or write the wrong element.
 */
kernel void ip_matrix(
    device const float* queries [[buffer(0)]],
    device const float* vectors [[buffer(1)]],
    device float* distances [[buffer(2)]],
    device const uint* params [[buffer(3)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 ltid [[thread_position_in_threadgroup]]
) {
    constexpr uint TILE_M = 32;
    constexpr uint TILE_N = 32;
    constexpr uint TILE_K = 16;
    constexpr uint TG_THREADS = 256;

    uint nq = params[0], nb = params[1], d = params[2];
    uint row0 = tgid.y * TILE_M;   // first query row of this tile
    uint col0 = tgid.x * TILE_N;   // first database vector of this tile
    uint ty = ltid.y, tx = ltid.x;
    uint tid = ty * 16 + tx;       // flat id used for the cooperative loads

    float acc00 = 0.0f, acc01 = 0.0f, acc10 = 0.0f, acc11 = 0.0f;

    threadgroup float tgQ[TILE_M * TILE_K];
    threadgroup float tgV[TILE_N * TILE_K];

    for (uint dk = 0; dk < d; dk += TILE_K) {
        uint kLen = min(TILE_K, d - dk);

        // Cooperative load of the query slice. Rows past nq and columns past
        // kLen are zero-filled so the inner loop needs no bounds checks.
        for (uint i = tid; i < TILE_M * TILE_K; i += TG_THREADS) {
            uint mr = i / TILE_K, mk = i % TILE_K;
            uint gRow = row0 + mr;
            tgQ[i] = (gRow < nq && mk < kLen)
                    ? queries[ulong(gRow) * d + dk + mk]
                    : 0.0f;
        }
        // Same for the database slice.
        for (uint i = tid; i < TILE_N * TILE_K; i += TG_THREADS) {
            uint mr = i / TILE_K, mk = i % TILE_K;
            uint gCol = col0 + mr;
            tgV[i] = (gCol < nb && mk < kLen)
                    ? vectors[ulong(gCol) * d + dk + mk]
                    : 0.0f;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Zero padding contributes 0 to every product, so running over the
        // full TILE_K is safe even when kLen < TILE_K.
        for (uint kk = 0; kk < TILE_K; kk++) {
            float q0 = tgQ[(ty * 2) * TILE_K + kk];
            float q1 = tgQ[(ty * 2 + 1) * TILE_K + kk];
            float v0 = tgV[(tx * 2) * TILE_K + kk];
            float v1 = tgV[(tx * 2 + 1) * TILE_K + kk];
            acc00 += q0 * v0; acc01 += q0 * v1;
            acc10 += q1 * v0; acc11 += q1 * v1;
        }
        // Do not overwrite the staged slices while other threads still read them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    uint r0 = row0 + ty * 2, r1 = r0 + 1;
    uint c0 = col0 + tx * 2, c1 = c0 + 1;
    ulong out0 = ulong(r0) * nb;
    ulong out1 = ulong(r1) * nb;
    // Only store the entries that fall inside the nq x nb result.
    if (r0 < nq && c0 < nb) distances[out0 + c0] = acc00;
    if (r0 < nq && c1 < nb) distances[out0 + c1] = acc01;
    if (r1 < nq && c0 < nb) distances[out1 + c0] = acc10;
    if (r1 < nq && c1 < nb) distances[out1 + c1] = acc11;
}
135+
136+
// ============================================================
// Parallel threadgroup-based top-k (bitonic sort)
// One threadgroup (256 threads) per query, 4 candidates per thread = 1024.
//
// TOPK_THREADGROUP_VARIANT(K) defines the kernel topk_threadgroup_K.
//
// Buffers:
//   distances    [[buffer(0)]]  read as nq rows of nb floats
//   outDistances [[buffer(1)]]  written as nq rows of k floats
//   outIndices   [[buffer(2)]]  written as nq rows of k ints
//   params       [[buffer(3)]]  {nq, nb, k, want_min}
//
// Steps for query row qi:
//   1. Thread tid scans columns tid, tid + 256, ... and keeps its R best
//      values in a small sorted array (smallest first when want_min is
//      non-zero, largest first otherwise).
//   2. The 256 * R candidates are copied to threadgroup memory; unused slots
//      get the sentinel +/-1e38f with index -1.
//   3. A bitonic sort orders all candidates best-first; equal distances are
//      ordered by smaller index.
//   4. The first K_out = min(K, k, nb) entries are written out, and the
//      remaining k - K_out output slots are filled with the sentinel / -1.
//
// Row offsets are computed as 64-bit `ulong`: the 32-bit products qi * nb and
// qi * k wrap once nq * nb (or nq * k) reaches 2^32.
//
// NOTE(review): each thread keeps only its R = 4 best candidates, so if more
// than R of a row's true top-k sit in one thread's stride (same column index
// modulo 256) the extras are dropped before the sort. The result looks exact
// only for k <= R unless the caller compensates - confirm.
//
// Comments inside the macro body must be /* ... */: a // comment would
// swallow the backslash-continued lines that follow it.
// ============================================================

#define TOPK_THREADGROUP_VARIANT(K) \
kernel void topk_threadgroup_##K( \
    device const float* distances [[buffer(0)]], \
    device float* outDistances [[buffer(1)]], \
    device int* outIndices [[buffer(2)]], \
    device const uint* params [[buffer(3)]], \
    uint qi [[threadgroup_position_in_grid]], \
    uint tid [[thread_position_in_threadgroup]] \
) { \
    constexpr uint TG_SIZE = 256; \
    constexpr uint R = 4; \
    constexpr uint CANDIDATES = TG_SIZE * R; \
    threadgroup float tgDist[CANDIDATES]; \
    threadgroup int tgIdx[CANDIDATES]; \
    uint nq = params[0], nb = params[1], k = params[2], want_min = params[3]; \
    /* qi is per-threadgroup, so the whole group leaves together. */ \
    if (qi >= nq || k == 0) return; \
    /* 64-bit row offsets: qi * nb and qi * k can exceed 2^32. */ \
    const device float* row = distances + ulong(qi) * nb; \
    const ulong outBase = ulong(qi) * k; \
    uint kk = min(k, nb); \
    uint K_out = min((uint)K, kk); \
    \
    float localDist[R]; \
    int localIdx[R]; \
    uint localCount = 0; \
    \
    /* Step 1: per-thread sorted list of the R best values in its stride. */ \
    for (uint j = tid; j < nb; j += TG_SIZE) { \
        float v = row[j]; \
        if (localCount < R) { \
            /* List not full yet: insertion-sort the new value in. */ \
            uint pos = localCount; \
            while (pos > 0 && ((want_min && v < localDist[pos-1]) || (!want_min && v > localDist[pos-1]))) { \
                localDist[pos] = localDist[pos-1]; \
                localIdx[pos] = localIdx[pos-1]; \
                pos--; \
            } \
            localDist[pos] = v; \
            localIdx[pos] = (int)j; \
            localCount++; \
        } else { \
            /* List full: replace the current worst entry only if v beats it. */ \
            bool better = want_min ? (v < localDist[R-1]) : (v > localDist[R-1]); \
            if (better) { \
                uint pos = R - 1; \
                while (pos > 0 && ((want_min && v < localDist[pos-1]) || (!want_min && v > localDist[pos-1]))) { \
                    localDist[pos] = localDist[pos-1]; \
                    localIdx[pos] = localIdx[pos-1]; \
                    pos--; \
                } \
                localDist[pos] = v; \
                localIdx[pos] = (int)j; \
            } \
        } \
    } \
    \
    /* Step 2: publish candidates; pad unused slots with the worst value. */ \
    for (uint i = 0; i < R; i++) { \
        uint idx = tid * R + i; \
        if (i < localCount) { \
            tgDist[idx] = localDist[i]; \
            tgIdx[idx] = localIdx[i]; \
        } else { \
            tgDist[idx] = want_min ? 1e38f : -1e38f; \
            tgIdx[idx] = -1; \
        } \
    } \
    threadgroup_barrier(mem_flags::mem_threadgroup); \
    \
    /* Step 3: bitonic sort of all CANDIDATES entries, best first. */ \
    for (uint k2 = 2; k2 <= CANDIDATES; k2 *= 2) { \
        for (uint j = k2 >> 1; j > 0; j >>= 1) { \
            for (uint idx = tid; idx < CANDIDATES; idx += TG_SIZE) { \
                uint partner = idx ^ j; \
                /* Only the lower index of each pair does the compare-exchange. */ \
                if (partner < CANDIDATES && partner > idx) { \
                    bool ascending = ((idx & k2) == 0); \
                    bool partnerBetter = want_min \
                        ? (tgDist[partner] < tgDist[idx] || (tgDist[partner] == tgDist[idx] && tgIdx[partner] < tgIdx[idx])) \
                        : (tgDist[partner] > tgDist[idx] || (tgDist[partner] == tgDist[idx] && tgIdx[partner] < tgIdx[idx])); \
                    bool idxBetter = want_min \
                        ? (tgDist[idx] < tgDist[partner] || (tgDist[idx] == tgDist[partner] && tgIdx[idx] < tgIdx[partner])) \
                        : (tgDist[idx] > tgDist[partner] || (tgDist[idx] == tgDist[partner] && tgIdx[idx] < tgIdx[partner])); \
                    bool swap = ascending ? partnerBetter : idxBetter; \
                    if (swap) { \
                        float td = tgDist[idx]; tgDist[idx] = tgDist[partner]; tgDist[partner] = td; \
                        int ti = tgIdx[idx]; tgIdx[idx] = tgIdx[partner]; tgIdx[partner] = ti; \
                    } \
                } \
            } \
            threadgroup_barrier(mem_flags::mem_threadgroup); \
        } \
    } \
    \
    /* Step 4: emit the K_out best, then pad the rest of the k output slots. */ \
    for (uint i = tid; i < K_out; i += TG_SIZE) { \
        outDistances[outBase + i] = tgDist[i]; \
        outIndices[outBase + i] = tgIdx[i]; \
    } \
    for (uint i = tid; i < k - K_out; i += TG_SIZE) { \
        outDistances[outBase + K_out + i] = want_min ? 1e38f : -1e38f; \
        outIndices[outBase + K_out + i] = -1; \
    } \
}

TOPK_THREADGROUP_VARIANT(32)
TOPK_THREADGROUP_VARIANT(64)
TOPK_THREADGROUP_VARIANT(128)
TOPK_THREADGROUP_VARIANT(256)

#undef TOPK_THREADGROUP_VARIANT

0 commit comments

Comments
 (0)