|
| 1 | +// @lint-ignore-every LICENSELINT |
| 2 | +/** |
| 3 | + * Copyright (c) Meta Platforms, Inc. and its affiliates. |
| 4 | + * |
| 5 | + * This source code is licensed under the MIT license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + * |
| 8 | + * Metal Shading Language kernels for distance computation and top-k selection. |
| 9 | + * |
| 10 | + * Kernel organization: |
| 11 | + * - Distance kernels: l2_squared_matrix, ip_matrix (tiled GEMM-style) |
| 12 | + * - Top-k selection: topk_threadgroup_K (parallel bitonic sort, K <= 256) |
| 13 | + */ |
| 14 | + |
| 15 | +#include <metal_stdlib> |
| 16 | +using namespace metal; |
| 17 | + |
/**
 * Squared-L2 distance matrix.
 *
 * For every query row q < nq and vector row v < nb this writes
 *   distances[q * nb + v] = sum over k < d of (queries[q * d + k] - vectors[v * d + k])^2
 *
 * Buffers:
 *   queries   (0)  read as nq rows of d floats, row-major
 *   vectors   (1)  read as nb rows of d floats, row-major
 *   distances (2)  written as nq rows of nb floats, row-major
 *   params    (3)  {nq, nb, d}
 *
 * Each threadgroup produces one TILE_M x TILE_N block of the output and each
 * thread accumulates a 2 x 2 patch of it. The indexing below therefore assumes
 * a 16 x 16 threadgroup (ltid.x and ltid.y both in [0, 16)); the dispatch is
 * not visible in this file, so confirm it against the host code.
 *
 * All global offsets are formed in 64 bits: nq * nb, nq * d and nb * d can
 * each exceed 2^32 for realistic problem sizes even though the individual
 * dimensions fit in a uint.
 */
kernel void l2_squared_matrix(
    device const float* queries [[buffer(0)]],
    device const float* vectors [[buffer(1)]],
    device float* distances [[buffer(2)]],
    device const uint* params [[buffer(3)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 ltid [[thread_position_in_threadgroup]]
) {
    constexpr uint TILE_M = 32;  // query rows per threadgroup tile
    constexpr uint TILE_N = 32;  // vector rows (output columns) per tile
    constexpr uint TILE_K = 16;  // slice of the feature dimension staged per step
    constexpr uint TG_COLS = TILE_N / 2;                  // threads along x, 2 columns each
    constexpr uint TG_THREADS = (TILE_M / 2) * TG_COLS;   // 256

    const uint nq = params[0];
    const uint nb = params[1];
    const uint d = params[2];

    const uint row0 = tgid.y * TILE_M;
    const uint col0 = tgid.x * TILE_N;
    const uint ty = ltid.y;
    const uint tx = ltid.x;
    const uint tid = ty * TG_COLS + tx;  // flat id used for the cooperative loads

    threadgroup float tgQ[TILE_M * TILE_K];
    threadgroup float tgV[TILE_N * TILE_K];

    float acc00 = 0.0f, acc01 = 0.0f, acc10 = 0.0f, acc11 = 0.0f;

    for (uint dk = 0; dk < d; dk += TILE_K) {
        const uint kLen = min(TILE_K, d - dk);

        // Stage the query slice. Rows past nq and features past d are stored
        // as 0 so the inner loop can always run the full TILE_K.
        for (uint i = tid; i < TILE_M * TILE_K; i += TG_THREADS) {
            const uint mr = i / TILE_K;
            const uint mk = i % TILE_K;
            const uint gRow = row0 + mr;
            const bool live = gRow < nq && mk < kLen;
            tgQ[i] = live ? queries[(ulong)gRow * d + dk + mk] : 0.0f;
        }
        // Stage the vector slice with the same zero padding; a padded feature
        // contributes (0 - 0)^2 = 0 to every accumulator.
        for (uint i = tid; i < TILE_N * TILE_K; i += TG_THREADS) {
            const uint mr = i / TILE_K;
            const uint mk = i % TILE_K;
            const uint gCol = col0 + mr;
            const bool live = gCol < nb && mk < kLen;
            tgV[i] = live ? vectors[(ulong)gCol * d + dk + mk] : 0.0f;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint kk = 0; kk < TILE_K; kk++) {
            const float q0 = tgQ[(ty * 2) * TILE_K + kk];
            const float q1 = tgQ[(ty * 2 + 1) * TILE_K + kk];
            const float v0 = tgV[(tx * 2) * TILE_K + kk];
            const float v1 = tgV[(tx * 2 + 1) * TILE_K + kk];
            const float d00 = q0 - v0;
            const float d01 = q0 - v1;
            const float d10 = q1 - v0;
            const float d11 = q1 - v1;
            acc00 += d00 * d00;
            acc01 += d01 * d01;
            acc10 += d10 * d10;
            acc11 += d11 * d11;
        }
        // Everyone must be done reading before the next slice overwrites the tiles.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const uint r0 = row0 + ty * 2, r1 = r0 + 1;
    const uint c0 = col0 + tx * 2, c1 = c0 + 1;
    if (r0 < nq && c0 < nb) distances[(ulong)r0 * nb + c0] = acc00;
    if (r0 < nq && c1 < nb) distances[(ulong)r0 * nb + c1] = acc01;
    if (r1 < nq && c0 < nb) distances[(ulong)r1 * nb + c0] = acc10;
    if (r1 < nq && c1 < nb) distances[(ulong)r1 * nb + c1] = acc11;
}
| 77 | + |
/**
 * Inner-product matrix.
 *
 * For every query row q < nq and vector row v < nb this writes
 *   distances[q * nb + v] = sum over k < d of queries[q * d + k] * vectors[v * d + k]
 *
 * Buffers:
 *   queries   (0)  read as nq rows of d floats, row-major
 *   vectors   (1)  read as nb rows of d floats, row-major
 *   distances (2)  written as nq rows of nb floats, row-major
 *   params    (3)  {nq, nb, d}
 *
 * Each threadgroup produces one TILE_M x TILE_N block of the output and each
 * thread accumulates a 2 x 2 patch of it. The indexing below therefore assumes
 * a 16 x 16 threadgroup (ltid.x and ltid.y both in [0, 16)); the dispatch is
 * not visible in this file, so confirm it against the host code.
 *
 * All global offsets are formed in 64 bits: nq * nb, nq * d and nb * d can
 * each exceed 2^32 for realistic problem sizes even though the individual
 * dimensions fit in a uint.
 */
kernel void ip_matrix(
    device const float* queries [[buffer(0)]],
    device const float* vectors [[buffer(1)]],
    device float* distances [[buffer(2)]],
    device const uint* params [[buffer(3)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint2 ltid [[thread_position_in_threadgroup]]
) {
    constexpr uint TILE_M = 32;  // query rows per threadgroup tile
    constexpr uint TILE_N = 32;  // vector rows (output columns) per tile
    constexpr uint TILE_K = 16;  // slice of the feature dimension staged per step
    constexpr uint TG_COLS = TILE_N / 2;                  // threads along x, 2 columns each
    constexpr uint TG_THREADS = (TILE_M / 2) * TG_COLS;   // 256

    const uint nq = params[0];
    const uint nb = params[1];
    const uint d = params[2];

    const uint row0 = tgid.y * TILE_M;
    const uint col0 = tgid.x * TILE_N;
    const uint ty = ltid.y;
    const uint tx = ltid.x;
    const uint tid = ty * TG_COLS + tx;  // flat id used for the cooperative loads

    threadgroup float tgQ[TILE_M * TILE_K];
    threadgroup float tgV[TILE_N * TILE_K];

    float acc00 = 0.0f, acc01 = 0.0f, acc10 = 0.0f, acc11 = 0.0f;

    for (uint dk = 0; dk < d; dk += TILE_K) {
        const uint kLen = min(TILE_K, d - dk);

        // Stage the query slice. Rows past nq and features past d are stored
        // as 0 so the inner loop can always run the full TILE_K.
        for (uint i = tid; i < TILE_M * TILE_K; i += TG_THREADS) {
            const uint mr = i / TILE_K;
            const uint mk = i % TILE_K;
            const uint gRow = row0 + mr;
            const bool live = gRow < nq && mk < kLen;
            tgQ[i] = live ? queries[(ulong)gRow * d + dk + mk] : 0.0f;
        }
        // Stage the vector slice with the same zero padding; a padded feature
        // contributes 0 * 0 = 0 to every accumulator.
        for (uint i = tid; i < TILE_N * TILE_K; i += TG_THREADS) {
            const uint mr = i / TILE_K;
            const uint mk = i % TILE_K;
            const uint gCol = col0 + mr;
            const bool live = gCol < nb && mk < kLen;
            tgV[i] = live ? vectors[(ulong)gCol * d + dk + mk] : 0.0f;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint kk = 0; kk < TILE_K; kk++) {
            const float q0 = tgQ[(ty * 2) * TILE_K + kk];
            const float q1 = tgQ[(ty * 2 + 1) * TILE_K + kk];
            const float v0 = tgV[(tx * 2) * TILE_K + kk];
            const float v1 = tgV[(tx * 2 + 1) * TILE_K + kk];
            acc00 += q0 * v0;
            acc01 += q0 * v1;
            acc10 += q1 * v0;
            acc11 += q1 * v1;
        }
        // Everyone must be done reading before the next slice overwrites the tiles.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    const uint r0 = row0 + ty * 2, r1 = r0 + 1;
    const uint c0 = col0 + tx * 2, c1 = c0 + 1;
    if (r0 < nq && c0 < nb) distances[(ulong)r0 * nb + c0] = acc00;
    if (r0 < nq && c1 < nb) distances[(ulong)r0 * nb + c1] = acc01;
    if (r1 < nq && c0 < nb) distances[(ulong)r1 * nb + c0] = acc10;
    if (r1 < nq && c1 < nb) distances[(ulong)r1 * nb + c1] = acc11;
}
| 135 | + |
// ============================================================
// Parallel threadgroup-based top-k (bitonic sort)
//
// One threadgroup of 256 threads per query row. The row is streamed through a
// 1024-slot threadgroup scratch buffer: the first round loads up to 1024
// entries and sorts them; every later round keeps the best K_out entries at
// the front, refills the remaining slots with the next entries of the row and
// re-sorts. The selection is therefore exact for any nb, not only nb <= 1024.
// A round whose fresh entries cannot displace the worst retained entry skips
// the sort.
//
// params (buffer 3) = {nq, nb, k, want_min}
//   want_min != 0 : smallest distances first; otherwise largest first.
//   Ties on distance are broken towards the smaller index.
// Output rows have stride k. Slots that cannot be filled (beyond
// min(K, k, nb)) receive index -1 and distance +1e38f / -1e38f.
//
// Assumptions (not enforced here, confirm against the host code): the kernel
// is dispatched with exactly 256 threads per threadgroup, and nb < 2^31 so a
// column index fits the int output without going negative.
// ============================================================

constant constexpr uint TOPK_TG_SIZE = 256;
constant constexpr uint TOPK_CANDIDATES = 1024;  // 4 scratch slots per thread

// True when entry (da, ia) must be placed before entry (db, ib).
// Padding entries carry a negative index and always rank after real entries,
// whatever their distance; real entries order by distance, then by index.
inline bool topk_ranks_before(float da, int ia, float db, int ib, bool want_min) {
    const bool pad_a = ia < 0;
    const bool pad_b = ib < 0;
    if (pad_a != pad_b) return pad_b;
    if (da != db) return want_min ? (da < db) : (da > db);
    return ia < ib;
}

// Bitonic sort of all TOPK_CANDIDATES scratch slots, best entry at slot 0.
// Must be called by every thread of the threadgroup; returns after a barrier,
// so the sorted data is visible to all threads on return.
inline void topk_bitonic_sort(
    threadgroup float* tgDist,
    threadgroup int* tgIdx,
    uint tid,
    bool want_min
) {
    for (uint span = 2; span <= TOPK_CANDIDATES; span <<= 1) {
        for (uint stride = span >> 1; stride > 0; stride >>= 1) {
            for (uint a = tid; a < TOPK_CANDIDATES; a += TOPK_TG_SIZE) {
                const uint b = a ^ stride;
                // Each pair is owned by its lower slot, so no two threads
                // touch the same slot within one stage.
                if (b > a) {
                    const bool best_first = (a & span) == 0;
                    const float da = tgDist[a];
                    const float db = tgDist[b];
                    const int ia = tgIdx[a];
                    const int ib = tgIdx[b];
                    const bool out_of_order = best_first
                        ? topk_ranks_before(db, ib, da, ia, want_min)
                        : topk_ranks_before(da, ia, db, ib, want_min);
                    if (out_of_order) {
                        tgDist[a] = db;
                        tgDist[b] = da;
                        tgIdx[a] = ib;
                        tgIdx[b] = ia;
                    }
                }
            }
            threadgroup_barrier(mem_flags::mem_threadgroup);
        }
    }
}

// Body shared by every topk_threadgroup_K kernel. max_k is the K of the
// variant; at most min(max_k, k, nb) real results are produced per row.
inline void topk_threadgroup_select(
    device const float* distances,
    device float* outDistances,
    device int* outIndices,
    device const uint* params,
    uint qi,
    uint tid,
    uint max_k,
    threadgroup float* tgDist,
    threadgroup int* tgIdx,
    threadgroup atomic_uint* tgRound
) {
    const uint nq = params[0];
    const uint nb = params[1];
    const uint k = params[2];
    const bool want_min = params[3] != 0;
    // Uniform across the threadgroup, so returning ahead of the barriers is safe.
    if (qi >= nq || k == 0) return;

    // 64-bit offsets: qi * nb and qi * k can exceed 2^32.
    device const float* row = distances + (ulong)qi * nb;
    device float* outDist = outDistances + (ulong)qi * k;
    device int* outIdx = outIndices + (ulong)qi * k;

    const uint K_out = min(max_k, min(k, nb));
    const float padDist = want_min ? 1e38f : -1e38f;

    // tgRound holds the number of the last round in which some thread saw a
    // fresh entry that beats the worst retained one. Round numbers start at 1.
    if (tid == 0) atomic_store_explicit(tgRound, 0u, memory_order_relaxed);
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint kept = 0;   // slots [0, kept) hold the sorted running result
    uint base = 0;   // first row entry not consumed yet
    uint round = 0;
    while (base < nb) {
        round++;
        const uint fresh = TOPK_CANDIDATES - kept;
        const uint avail = nb - base;

        // Worst retained entry; in the first round nothing is retained, so
        // the cut is a padding entry and every real entry beats it.
        float cutDist = padDist;
        int cutIdx = -1;
        if (kept > 0) {
            cutDist = tgDist[kept - 1];
            cutIdx = tgIdx[kept - 1];
        }

        bool improves = false;
        for (uint slot = kept + tid; slot < TOPK_CANDIDATES; slot += TOPK_TG_SIZE) {
            const uint off = slot - kept;
            const bool live = off < avail;
            const float dv = live ? row[base + off] : padDist;
            const int iv = live ? (int)(base + off) : -1;
            tgDist[slot] = dv;
            tgIdx[slot] = iv;
            improves = improves || topk_ranks_before(dv, iv, cutDist, cutIdx, want_min);
        }
        if (improves) atomic_store_explicit(tgRound, round, memory_order_relaxed);
        threadgroup_barrier(mem_flags::mem_threadgroup);

        const bool mustSort =
            atomic_load_explicit(tgRound, memory_order_relaxed) == round;
        // Nobody may vote for the next round until every thread has read this
        // round's verdict; otherwise threads could disagree on mustSort.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (mustSort) topk_bitonic_sort(tgDist, tgIdx, tid, want_min);

        base += min(fresh, avail);  // clamped so base never wraps past nb
        kept = K_out;
    }

    for (uint i = tid; i < K_out; i += TOPK_TG_SIZE) {
        outDist[i] = tgDist[i];
        outIdx[i] = tgIdx[i];
    }
    // Pad the rest of the row when fewer than k results exist.
    for (uint i = K_out + tid; i < k; i += TOPK_TG_SIZE) {
        outDist[i] = padDist;
        outIdx[i] = -1;
    }
}

// Declares kernel topk_threadgroup_<K>: buffer bindings, scratch storage and
// a call into the shared body with K as the per-variant result cap.
#define TOPK_THREADGROUP_VARIANT(K)                                        \
kernel void topk_threadgroup_##K(                                          \
    device const float* distances [[buffer(0)]],                           \
    device float* outDistances [[buffer(1)]],                              \
    device int* outIndices [[buffer(2)]],                                  \
    device const uint* params [[buffer(3)]],                               \
    uint qi [[threadgroup_position_in_grid]],                              \
    uint tid [[thread_position_in_threadgroup]]                            \
) {                                                                        \
    threadgroup float tgDist[TOPK_CANDIDATES];                             \
    threadgroup int tgIdx[TOPK_CANDIDATES];                                \
    threadgroup atomic_uint tgRound;                                       \
    topk_threadgroup_select(distances, outDistances, outIndices, params,   \
                            qi, tid, (uint)K, tgDist, tgIdx, &tgRound);    \
}

TOPK_THREADGROUP_VARIANT(32)
TOPK_THREADGROUP_VARIANT(64)
TOPK_THREADGROUP_VARIANT(128)
TOPK_THREADGROUP_VARIANT(256)
#undef TOPK_THREADGROUP_VARIANT
0 commit comments