-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Expand file tree
/
Copy pathsuffixAutomatonKernels.cu
More file actions
416 lines (349 loc) · 13.5 KB
/
suffixAutomatonKernels.cu
File metadata and controls
416 lines (349 loc) · 13.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* Adapted from Baseten's sa_spec library (Apache-2.0)
* https://github.com/basetenlabs/sa_spec
*/
#include <cassert>
#include "suffixAutomatonKernels.h"
#include "tensorrt_llm/common/config.h"
TRTLLM_NAMESPACE_BEGIN
namespace kernels::speculative_decoding::suffix_automaton
{
// Appends each request's newly accepted tokens to its suffix automaton, then looks up the longest
// suffix match and emits draft tokens from it.
//
// Launch layout: one block per request (blockIdx.x is the request index); only thread 0 of each
// block does any work, so extra threads and blocks past batchSize exit immediately.
// Per request `req`:
//   batchIndices[req]   - slot index into slotsMemory (asserted to lie in [0, maxSlots))
//   acceptedLensIn[req] - number of accepted tokens to append (asserted to lie in [0, draftLength + 1])
//   acceptedTokensIn    - row `req` with stride draftLength + 1 holds the accepted tokens
//   matchLenOut[req]    - match length from lookup(), or 0 when there is no match
//   draftTokensOut      - row `req` with stride draftLength; filled by getDraftTokens only on a match
__global__ void suffixAutomatonExtendKernel(int batchSize, int draftLength, int maxSlots, size_t stateSize,
    void* slotsMemory, int const* batchIndices, int* matchLenOut, int* draftTokensOut, int const* acceptedTokensIn,
    int const* acceptedLensIn)
{
    int const req = static_cast<int>(blockIdx.x);
    if (threadIdx.x != 0 || req >= batchSize)
    {
        return;
    }

    int const slotIdx = batchIndices[req];
    assert(slotIdx >= 0 && slotIdx < maxSlots);
    // Slots are packed back to back, stateSize bytes apart.
    SuffixAutomaton* sa = reinterpret_cast<SuffixAutomaton*>(
        static_cast<uint8_t*>(slotsMemory) + static_cast<size_t>(slotIdx) * stateSize);

    int const acceptedCount = acceptedLensIn[req];
    // Keeps reads inside this request's row of acceptedTokensIn.
    assert(acceptedCount >= 0 && acceptedCount <= draftLength + 1);
    int const* acceptedRow = acceptedTokensIn + req * (draftLength + 1);
    for (int k = 0; k < acceptedCount; ++k)
    {
        sa->extend(Token(acceptedRow[k]));
    }

    auto match = sa->lookup();
    if (!match.hasValue())
    {
        matchLenOut[req] = 0;
        return;
    }
    matchLenOut[req] = match->len;
    sa->getDraftTokens(draftTokensOut + req * draftLength, draftLength, match->pos);
}
// Host launcher for suffixAutomatonExtendKernel.
//
// Validates params, clamps the number of launched requests to maxSlots, and enqueues one
// single-thread block per request on `stream`. The launch is asynchronous and no error is
// checked here.
void invokeSuffixAutomatonExtend(SuffixAutomatonExtendParams const& params, cudaStream_t stream)
{
    params.checkParams();
    int batchSize = params.batchSize;
    int maxSlots = params.maxSlots;
    if (batchSize > maxSlots)
    {
        batchSize = maxSlots;
    }
    // A zero-sized grid is an invalid launch configuration, so an empty batch launches nothing.
    if (batchSize <= 0)
    {
        return;
    }
    size_t stateSize = getSuffixAutomatonStateSize(params.maxSeqLen);
    // Launch one block per sequence, one thread per block
    suffixAutomatonExtendKernel<<<batchSize, 1, 0, stream>>>(batchSize, params.draftLength, maxSlots, stateSize,
        params.slots, params.batchIndices, params.matchLenOut, params.draftTokensOut, params.acceptedTokensIn,
        params.acceptedLensIn);
}
// Appends each request's newly accepted tokens to its suffix automaton, then searches for a match
// and emits draft tokens from it.
//
// Match selection:
//   maxNgramSize == -1 : longest suffix match via lookup()
//   otherwise          : lookupFixed(n) for n = maxNgramSize, maxNgramSize - 1, ..., 1, stopping at
//                        the first n that yields a match (no lookup at all when maxNgramSize < 1)
//
// Launch layout: one block per request (blockIdx.x is the request index); only thread 0 of each
// block does any work. Buffer layouts match suffixAutomatonExtendKernel: acceptedTokensIn rows
// have stride draftLength + 1, and draftTokensOut rows have stride draftLength. matchLenOut[req]
// is 0 when nothing matched, and draftTokensOut is then left untouched.
__global__ void suffixAutomatonExtendNgramKernel(int batchSize, int draftLength, int maxNgramSize, int maxSlots,
    size_t stateSize, void* slotsMemory, int const* batchIndices, int* matchLenOut, int* draftTokensOut,
    int const* acceptedTokensIn, int const* acceptedLensIn)
{
    int const req = static_cast<int>(blockIdx.x);
    if (threadIdx.x != 0 || req >= batchSize)
    {
        return;
    }

    int const slotIdx = batchIndices[req];
    assert(slotIdx >= 0 && slotIdx < maxSlots);
    // Slots are packed back to back, stateSize bytes apart.
    SuffixAutomaton* sa = reinterpret_cast<SuffixAutomaton*>(
        static_cast<uint8_t*>(slotsMemory) + static_cast<size_t>(slotIdx) * stateSize);

    int const acceptedCount = acceptedLensIn[req];
    // Keeps reads inside this request's row of acceptedTokensIn.
    assert(acceptedCount >= 0 && acceptedCount <= draftLength + 1);
    int const* acceptedRow = acceptedTokensIn + req * (draftLength + 1);
    for (int k = 0; k < acceptedCount; ++k)
    {
        sa->extend(Token(acceptedRow[k]));
    }

    SAOptional<SuffixAutomaton::LookupResult> match;
    if (maxNgramSize == -1)
    {
        match = sa->lookup();
    }
    else
    {
        // Shrink the n-gram until something matches.
        for (int n = maxNgramSize; n >= 1 && !match.hasValue(); --n)
        {
            match = sa->lookupFixed(n);
        }
    }

    if (!match.hasValue())
    {
        matchLenOut[req] = 0;
        return;
    }
    matchLenOut[req] = match->len;
    sa->getDraftTokens(draftTokensOut + req * draftLength, draftLength, match->pos);
}
// Host launcher for suffixAutomatonExtendNgramKernel.
//
// Validates params, clamps the number of launched requests to maxSlots, and enqueues one
// single-thread block per request on `stream`. The launch is asynchronous and no error is
// checked here.
void invokeSuffixAutomatonExtendNgram(SuffixAutomatonExtendNgramParams const& params, cudaStream_t stream)
{
    params.checkParams();
    int batchSize = params.batchSize;
    int maxSlots = params.maxSlots;
    if (batchSize > maxSlots)
    {
        batchSize = maxSlots;
    }
    // A zero-sized grid is an invalid launch configuration, so an empty batch launches nothing.
    if (batchSize <= 0)
    {
        return;
    }
    size_t stateSize = getSuffixAutomatonStateSize(params.maxSeqLen);
    // Launch one block per sequence, one thread per block
    suffixAutomatonExtendNgramKernel<<<batchSize, 1, 0, stream>>>(batchSize, params.draftLength, params.maxNgramSize,
        maxSlots, stateSize, params.slots, params.batchIndices, params.matchLenOut, params.draftTokensOut,
        params.acceptedTokensIn, params.acceptedLensIn);
}
// =====================================================================
// Global search kernels (cross-request pattern sharing)
// =====================================================================
// Kernel 1: Extend all SAs with accepted tokens.
// Separate kernel ensures all mutations complete before cross-SA reads.
//
// Launch layout: one block per request (blockIdx.x is the request index). Only thread 0 of each
// block appends acceptedLensIn[reqIdx] tokens, taken from row reqIdx (stride draftLength + 1) of
// acceptedTokensIn, to the request's own slot batchIndices[reqIdx].
__global__ void suffixAutomatonGlobalExtendKernel(int batchSize, int draftLength, int maxSlots, size_t stateSize,
    void* slotsMemory, int const* batchIndices, int const* acceptedTokensIn, int const* acceptedLensIn)
{
    // extend() mutates the automaton, so only one thread per block may run it. This matches the
    // other extend kernels and keeps the kernel race-free if it is ever launched with blockDim.x > 1.
    if (threadIdx.x > 0)
    {
        return;
    }
    int reqIdx = blockIdx.x;
    if (reqIdx >= batchSize)
    {
        return;
    }
    int ownSlotIdx = batchIndices[reqIdx];
    assert(ownSlotIdx >= 0 && ownSlotIdx < maxSlots);
    // Slots are packed back to back, stateSize bytes apart.
    uint8_t* slotMemory = static_cast<uint8_t*>(slotsMemory) + static_cast<size_t>(ownSlotIdx) * stateSize;
    SuffixAutomaton* ownSlot = reinterpret_cast<SuffixAutomaton*>(slotMemory);
    int numNewTokens = acceptedLensIn[reqIdx];
    // Keeps reads inside this request's row of acceptedTokensIn.
    assert(numNewTokens >= 0 && numNewTokens <= draftLength + 1);
    for (int j = 0; j < numNewTokens; j++)
    {
        ownSlot->extend(Token(acceptedTokensIn[reqIdx * (draftLength + 1) + j]));
    }
}
// Per-thread match result for shared-memory parallel reduction in the global search kernel.
// Field order matters: the search kernel initializes this struct positionally, as in
// {matchLen, continuationLen, isOwnSlot, slotIdx, pos}.
struct SlotMatch
{
    int matchLen;        // match length reported by the slot's lookup; 0 means "no usable match"
    int continuationLen; // tokens in the matched slot after pos, i.e. mTokens.size() - pos
    int isOwnSlot;       // 1 if the match came from the requesting sequence's own slot, else 0
    int slotIdx;         // index of the slot that produced the match; -1 when no match was recorded
    TextIndex pos;       // match position in that slot, passed to getDraftTokens if this entry wins
};
// kMaxGlobalSuffixLen is defined in suffixAutomatonParams.h and is the capacity of the shared
// suffix buffer below. When maxNgramSize <= 0 (e.g. -1), the query is the last kMaxGlobalSuffixLen
// tokens of the request's own text. A positive maxNgramSize larger than kMaxGlobalSuffixLen is
// clamped to that limit as well, so longer suffixes are silently truncated.

// Returns true when `candidate` should replace `current` in the best-match selection:
//   1. a longer matchLen wins;
//   2. among equal, non-zero matchLen, the request's own slot wins;
//   3. among equal matchLen and equal locality, the longer continuation wins.
// Exact ties keep `current`.
static __device__ __forceinline__ bool isBetterSlotMatch(SlotMatch const& candidate, SlotMatch const& current)
{
    if (candidate.matchLen != current.matchLen)
    {
        return candidate.matchLen > current.matchLen;
    }
    if (candidate.matchLen <= 0)
    {
        return false;
    }
    if (candidate.isOwnSlot != current.isOwnSlot)
    {
        return candidate.isOwnSlot > current.isOwnSlot;
    }
    return candidate.continuationLen > current.continuationLen;
}

// Kernel 2: Search all active SAs in parallel, reduce to best match per request.
// All SAs are read-only (const) — launched after the extend kernel on the same stream.
//
// Launch layout: one block per request (blockIdx.x is the request index). blockDim.x must be a
// power of two, and the dynamic shared memory must hold blockDim.x SlotMatch entries. Each thread
// scans slots threadIdx.x, threadIdx.x + blockDim.x, ..., so every slot in [0, maxSlots) is searched
// even when maxSlots exceeds blockDim.x (the host launcher caps blockDim.x at 1024).
// Outputs per request: matchLenOut (0 if none), matchSlotOut (-1 if none), and the draft tokens in
// row reqIdx (stride draftLength) of draftTokensOut, written only on a match.
__global__ void suffixAutomatonGlobalSearchKernel(int batchSize, int draftLength, int maxNgramSize, int maxSlots,
    size_t stateSize, void const* slotsMemory, int const* batchIndices, int const* activeSlotMask, int* matchLenOut,
    int* matchSlotOut, int* draftTokensOut)
{
    extern __shared__ SlotMatch sharedMatches[];
    int reqIdx = blockIdx.x;
    int tid = threadIdx.x;
    // Block-uniform exit, so no thread skips a later __syncthreads() that others reach.
    if (reqIdx >= batchSize)
    {
        return;
    }
    int ownSlotIdx = batchIndices[reqIdx];
    assert(ownSlotIdx >= 0 && ownSlotIdx < maxSlots);

    // Step 1: Extract suffix from own SA into shared memory
    __shared__ Token sharedSuffix[kMaxGlobalSuffixLen];
    __shared__ int suffixLen;
    if (tid == 0)
    {
        uint8_t const* slotMem = static_cast<uint8_t const*>(slotsMemory) + static_cast<size_t>(ownSlotIdx) * stateSize;
        SuffixAutomaton const* ownSlot = reinterpret_cast<SuffixAutomaton const*>(slotMem);
        int const suffixCapacity = static_cast<int>(kMaxGlobalSuffixLen);
        int maxSuffixLen = (maxNgramSize > 0) ? maxNgramSize : suffixCapacity;
        // Never copy more tokens than sharedSuffix can hold, even for a large positive maxNgramSize.
        if (maxSuffixLen > suffixCapacity)
        {
            maxSuffixLen = suffixCapacity;
        }
        int textLen = +ownSlot->mTokens.size();
        suffixLen = (maxSuffixLen < textLen) ? maxSuffixLen : textLen;
        for (int i = 0; i < suffixLen; i++)
        {
            sharedSuffix[i] = ownSlot->mTokens.at(TextIndex(textLen - suffixLen + i));
        }
    }
    __syncthreads();

    // Step 2: Each thread searches a strided subset of slots and keeps its best match
    SlotMatch myMatch = {0, 0, 0, -1, TextIndex(0)};
    for (int slotIdx = tid; slotIdx < maxSlots; slotIdx += static_cast<int>(blockDim.x))
    {
        if (!activeSlotMask[slotIdx])
        {
            continue;
        }
        uint8_t const* slotMem = static_cast<uint8_t const*>(slotsMemory) + static_cast<size_t>(slotIdx) * stateSize;
        SuffixAutomaton const* slot = reinterpret_cast<SuffixAutomaton const*>(slotMem);
        auto result = slot->lookupWithSuffix(sharedSuffix, suffixLen);
        if (result.hasValue())
        {
            SlotMatch candidate = {0, 0, 0, -1, TextIndex(0)};
            candidate.matchLen = result->len;
            candidate.continuationLen = +slot->mTokens.size() - +result->pos;
            candidate.isOwnSlot = (slotIdx == ownSlotIdx) ? 1 : 0;
            candidate.slotIdx = slotIdx;
            candidate.pos = result->pos;
            if (isBetterSlotMatch(candidate, myMatch))
            {
                myMatch = candidate;
            }
        }
    }
    sharedMatches[tid] = myMatch;
    __syncthreads();

    // Step 3: Parallel tree reduction using the same ordering as isBetterSlotMatch.
    // Requires blockDim.x to be a power of 2 (guaranteed by nextPowerOf2 in the host launcher).
    for (int stride = blockDim.x / 2; stride > 0; stride >>= 1)
    {
        if (tid < stride)
        {
            SlotMatch const& candidate = sharedMatches[tid + stride];
            if (isBetterSlotMatch(candidate, sharedMatches[tid]))
            {
                sharedMatches[tid] = candidate;
            }
        }
        __syncthreads();
    }

    // Step 4: Thread 0 writes output
    if (tid == 0)
    {
        SlotMatch best = sharedMatches[0];
        if (best.matchLen > 0 && best.slotIdx >= 0)
        {
            matchLenOut[reqIdx] = best.matchLen;
            matchSlotOut[reqIdx] = best.slotIdx;
            uint8_t const* slotMem
                = static_cast<uint8_t const*>(slotsMemory) + static_cast<size_t>(best.slotIdx) * stateSize;
            SuffixAutomaton const* slot = reinterpret_cast<SuffixAutomaton const*>(slotMem);
            slot->getDraftTokens(&draftTokensOut[reqIdx * draftLength], draftLength, best.pos);
        }
        else
        {
            matchLenOut[reqIdx] = 0;
            matchSlotOut[reqIdx] = -1;
        }
    }
}
namespace
{
// Returns the smallest power of two that is >= v, or 1 when v <= 1.
// Inputs above 2^30 saturate to 2^30, the largest power of two representable in int. The original
// bit-smearing on a signed int overflowed (undefined behavior) for those inputs. The arithmetic
// is therefore done in unsigned, where shifts and the final increment are well defined.
int nextPowerOf2(int v)
{
    constexpr int kLargestPow2 = 1 << 30;
    if (v <= 1)
    {
        return 1;
    }
    if (v > kLargestPow2)
    {
        return kLargestPow2;
    }
    // Smear the highest set bit of (v - 1) into all lower bits, then add one.
    unsigned int u = static_cast<unsigned int>(v) - 1u;
    u |= u >> 1;
    u |= u >> 2;
    u |= u >> 4;
    u |= u >> 8;
    u |= u >> 16;
    return static_cast<int>(u + 1u);
}
} // anonymous namespace
// Host launcher for the two-phase global (cross-request) search.
//
// Enqueues, in order on `stream`:
//   1. suffixAutomatonGlobalExtendKernel: one single-thread block per request, extending each
//      request's own automaton with its accepted tokens;
//   2. suffixAutomatonGlobalSearchKernel: one block per request with a power-of-two thread count
//      (capped at 1024) and blockDim.x SlotMatch entries of dynamic shared memory.
// Stream ordering guarantees every extension finishes before any cross-slot read. Launches are
// asynchronous and no error is checked here.
void invokeSuffixAutomatonGlobalSearch(SuffixAutomatonGlobalSearchParams const& params, cudaStream_t stream)
{
    params.checkParams();
    int batchSize = params.batchSize;
    int maxSlots = params.maxSlots;
    if (batchSize > maxSlots)
    {
        batchSize = maxSlots;
    }
    // A zero-sized grid is an invalid launch configuration, so an empty batch launches nothing.
    if (batchSize <= 0)
    {
        return;
    }
    size_t stateSize = getSuffixAutomatonStateSize(params.maxSeqLen);
    // Kernel 1: Extend all SAs (1 thread per block, 1 block per request)
    suffixAutomatonGlobalExtendKernel<<<batchSize, 1, 0, stream>>>(batchSize, params.draftLength, maxSlots, stateSize,
        params.slots, params.batchIndices, params.acceptedTokensIn, params.acceptedLensIn);
    // Kernel 2: Global search + reduce (N threads per block, 1 block per request).
    // The tree reduction needs a power-of-two block size; 1024 is CUDA's per-block thread limit.
    int threadsPerBlock = nextPowerOf2(maxSlots);
    threadsPerBlock = (threadsPerBlock < 1024) ? threadsPerBlock : 1024;
    size_t sharedMemSize = static_cast<size_t>(threadsPerBlock) * sizeof(SlotMatch);
    suffixAutomatonGlobalSearchKernel<<<batchSize, threadsPerBlock, sharedMemSize, stream>>>(batchSize,
        params.draftLength, params.maxNgramSize, maxSlots, stateSize, params.slots, params.batchIndices,
        params.activeSlotMask, params.matchLenOut, params.matchSlotOut, params.draftTokensOut);
}
// Returns the per-slot byte size reported by SuffixAutomaton::getRequiredMemorySize(maxSeqLen).
// The launchers in this file space consecutive slots exactly this many bytes apart.
// maxSeqLen is presumably the maximum number of tokens a slot must hold (confirm in the header).
size_t getSuffixAutomatonStateSize(size_t maxSeqLen)
{
    size_t const bytesPerSlot = SuffixAutomaton::getRequiredMemorySize(maxSeqLen);
    return bytesPerSlot;
}
// Constructs a SuffixAutomaton in place at `memory` (value-initialized via placement new), then
// calls init(memory, maxSeqLen) on it.
// NOTE(review): `memory` presumably must provide getSuffixAutomatonStateSize(maxSeqLen) bytes —
// confirm with callers.
void initAutomaton(void* memory, size_t maxSeqLen)
{
    SuffixAutomaton* const automaton = new (memory) SuffixAutomaton();
    automaton->init(memory, maxSeqLen);
}
// Feeds tokens[0 .. numTokens) into `sa` one at a time via extend(). A non-positive numTokens
// leaves the automaton unchanged.
void buildAutomatonFromTokens(SuffixAutomaton* sa, int const* tokens, int numTokens)
{
    int idx = 0;
    while (idx < numTokens)
    {
        sa->extend(Token(tokens[idx]));
        ++idx;
    }
}
// Forwards to SuffixAutomaton::relocate(oldBase, newBase).
// NOTE(review): presumably used after the automaton's backing memory has moved from oldBase to
// newBase — confirm with callers.
void relocateAutomaton(SuffixAutomaton* sa, void* oldBase, void* newBase)
{
    SuffixAutomaton& automaton = *sa;
    automaton.relocate(oldBase, newBase);
}
} // namespace kernels::speculative_decoding::suffix_automaton
TRTLLM_NAMESPACE_END