Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -375,6 +375,84 @@ struct SuffixAutomaton
return SAOptional<LookupResult>();
}

/**
 * @brief Find the longest suffix of an external token sequence that appears
 * as a substring in this SA's text, then return its continuation position.
 *
 * Uses the standard longest-common-substring algorithm: process suffix tokens
 * in forward order through the SA, following suffix links on mismatch.
 *
 * Time complexity: O(suffixLen) amortized. Each token either advances the
 * match or triggers suffix link fallbacks. Since matchedLen increases at most
 * suffixLen times and never goes below 0, total suffix link hops is bounded
 * by suffixLen.
 *
 * @param suffix Pointer to the suffix tokens (forward order: oldest to newest)
 * @param suffixLen Number of tokens in the suffix
 * @return Optional LookupResult with continuation position and match length,
 *         or an empty optional when nothing matched or no continuation exists.
 */
SA_CUDA_CALLABLE SAOptional<LookupResult> lookupWithSuffix(Token const* suffix, int suffixLen) const
{
    // Empty automaton or empty query: nothing can match.
    if (mStates.empty() || suffixLen <= 0)
    {
        return SAOptional<LookupResult>();
    }

    // Phase 1: classic SA longest-common-substring scan. Invariant after
    // processing suffix[0..i]: `state` recognizes the longest suffix of
    // suffix[0..i] that occurs somewhere in the text, and `matchedLen` is
    // that suffix's length. NodeIndex(0) is the root (empty string).
    NodeIndex state = NodeIndex(0);
    int matchedLen = 0;

    for (int i = 0; i < suffixLen; i++)
    {
        Token token = suffix[i];

        // On mismatch, fall back along suffix links. Each hop shortens the
        // current match to the link target's longest represented length
        // (mStates.at(state).len), preserving the invariant above.
        while (state != NodeIndex(0) && mStates.at(state, token) == nullptr)
        {
            state = *mStates.at(state).link;
            matchedLen = mStates.at(state).len;
        }

        NodeIndex const* nextPtr = mStates.at(state, token);
        if (nextPtr != nullptr)
        {
            state = *nextPtr;
            matchedLen++;
        }
        // else: token occurs nowhere in the text; we are at the root with
        // matchedLen == 0 and matching restarts from the next token.
    }

    // No part of the query matched.
    if (matchedLen == 0 || state == NodeIndex(0))
    {
        return SAOptional<LookupResult>();
    }

    // Phase 2: find a usable occurrence. Walk suffix links until a state with
    // a recorded end position `pos` whose continuation (pos + 1) still lies
    // inside the text is found; each hop shrinks matchedLen to the shorter
    // match represented by the link target.
    while (state != NodeIndex(0))
    {
        auto& nodeData = mStates.at(state);
        SAOptional<TextIndex> posOpt = nodeData.pos;

        if (posOpt.hasValue())
        {
            TextIndex pos = *posOpt;
            // Only accept occurrences that have at least one token after
            // them — the caller drafts from the continuation position.
            if (+pos + 1 < +mTokens.size())
            {
                LookupResult result;
                result.pos = TextIndex(+pos + 1);
                result.len = matchedLen;
                return SAOptional<LookupResult>(result);
            }
        }

        auto linkOpt = nodeData.link;
        if (!linkOpt.hasValue())
        {
            break;
        }
        state = *linkOpt;
        matchedLen = mStates.at(state).len;
    }

    // Matched only at the very end of the text (no continuation available).
    return SAOptional<LookupResult>();
}

SA_CUDA_CALLABLE void getDraftTokens(Token::ValueType* buf, int bufLen, TextIndex startPos) const
{
int availableLen = +mTokens.size() - +startPos;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,213 @@ void invokeSuffixAutomatonExtendNgram(SuffixAutomatonExtendNgramParams const& pa
params.acceptedTokensIn, params.acceptedLensIn);
}

// =====================================================================
// Global search kernels (cross-request pattern sharing)
// =====================================================================

// Kernel 1: Extend all SAs with accepted tokens.
// Launched as one block of one thread per request; kept separate from the
// search kernel so that every SA mutation is complete before any cross-SA read.
__global__ void suffixAutomatonGlobalExtendKernel(int batchSize, int draftLength, int maxSlots, size_t stateSize,
    void* slotsMemory, int const* batchIndices, int const* acceptedTokensIn, int const* acceptedLensIn)
{
    int const requestIdx = blockIdx.x;
    if (requestIdx >= batchSize)
    {
        return;
    }

    // Locate this request's own SA inside the flat slots arena.
    int const slotIdx = batchIndices[requestIdx];
    assert(slotIdx >= 0 && slotIdx < maxSlots);
    auto* arena = static_cast<uint8_t*>(slotsMemory);
    auto* automaton = reinterpret_cast<SuffixAutomaton*>(arena + static_cast<size_t>(slotIdx) * stateSize);

    int const acceptedCount = acceptedLensIn[requestIdx];
    assert(acceptedCount >= 0 && acceptedCount <= draftLength + 1);

    // Accepted tokens for request r are packed at stride (draftLength + 1).
    int const* accepted = acceptedTokensIn + requestIdx * (draftLength + 1);
    for (int j = 0; j < acceptedCount; j++)
    {
        automaton->extend(Token(accepted[j]));
    }
}

// Per-thread match result for shared-memory parallel reduction
struct SlotMatch
{
    int matchLen;        // length of the matched suffix; 0 means "no match"
    int continuationLen; // tokens remaining after the continuation position in the matched slot's text
    int isOwnSlot;       // 1 if the match came from the request's own slot, else 0 (reduction tie-breaker)
    int slotIdx;         // slot that produced the match, or -1 if none
    TextIndex pos;       // continuation position (drafting starts here) in the matched slot's text
};

// kMaxGlobalSuffixLen is defined in suffixAutomatonParams.h.
// With maxNgramSize <= 0 (e.g. -1, meaning "unlimited"), longer suffixes are silently truncated to that limit.

// Kernel 2: Search all active SAs in parallel, reduce to best match per request.
// All SAs are read-only (const) — launched after the extend kernel on the same stream.
__global__ void suffixAutomatonGlobalSearchKernel(int batchSize, int draftLength, int maxNgramSize, int maxSlots,
size_t stateSize, void const* slotsMemory, int const* batchIndices, int const* activeSlotMask, int* matchLenOut,
int* matchSlotOut, int* draftTokensOut)
{
extern __shared__ SlotMatch sharedMatches[];

int reqIdx = blockIdx.x;
int slotIdx = threadIdx.x;

if (reqIdx >= batchSize)
{
return;
}

int ownSlotIdx = batchIndices[reqIdx];
assert(ownSlotIdx >= 0 && ownSlotIdx < maxSlots);

// Step 1: Extract suffix from own SA into shared memory
__shared__ Token sharedSuffix[kMaxGlobalSuffixLen];
__shared__ int suffixLen;

if (slotIdx == 0)
{
uint8_t const* slotMem = static_cast<uint8_t const*>(slotsMemory) + static_cast<size_t>(ownSlotIdx) * stateSize;
SuffixAutomaton const* ownSlot = reinterpret_cast<SuffixAutomaton const*>(slotMem);

int maxSuffixLen = (maxNgramSize > 0) ? maxNgramSize : kMaxGlobalSuffixLen;
int textLen = +ownSlot->mTokens.size();
suffixLen = (maxSuffixLen < textLen) ? maxSuffixLen : textLen;

for (int i = 0; i < suffixLen; i++)
{
sharedSuffix[i] = ownSlot->mTokens.at(TextIndex(textLen - suffixLen + i));
}
}
__syncthreads();

// Step 2: Each thread searches one slot
SlotMatch myMatch = {0, 0, 0, -1, TextIndex(0)};

if (slotIdx < maxSlots && activeSlotMask[slotIdx])
{
uint8_t const* slotMem = static_cast<uint8_t const*>(slotsMemory) + static_cast<size_t>(slotIdx) * stateSize;
SuffixAutomaton const* slot = reinterpret_cast<SuffixAutomaton const*>(slotMem);

auto result = slot->lookupWithSuffix(sharedSuffix, suffixLen);
if (result.hasValue())
{
myMatch.matchLen = result->len;
myMatch.continuationLen = +slot->mTokens.size() - +result->pos;
myMatch.isOwnSlot = (slotIdx == ownSlotIdx) ? 1 : 0;
myMatch.slotIdx = slotIdx;
myMatch.pos = result->pos;
}
}

sharedMatches[slotIdx] = myMatch;
__syncthreads();

// Step 3: Parallel reduction — three-level comparison:
// 1. Prefer longer match (higher matchLen)
// 2. Among equal matchLen, prefer own slot
// 3. Among equal matchLen and same locality, prefer longer continuation
// Requires blockDim.x to be a power of 2 (guaranteed by nextPowerOf2 in the host launcher).
for (int stride = blockDim.x / 2; stride > 0; stride >>= 1)
{
if (slotIdx < stride)
{
auto& current = sharedMatches[slotIdx];
auto& candidate = sharedMatches[slotIdx + stride];
bool replace = false;
if (candidate.matchLen > current.matchLen)
{
replace = true;
}
else if (candidate.matchLen == current.matchLen && candidate.matchLen > 0)
{
if (candidate.isOwnSlot > current.isOwnSlot)
{
replace = true;
}
else if (candidate.isOwnSlot == current.isOwnSlot
&& candidate.continuationLen > current.continuationLen)
{
replace = true;
}
}
if (replace)
{
current = candidate;
}
}
__syncthreads();
}

// Step 4: Thread 0 writes output
if (slotIdx == 0)
{
SlotMatch best = sharedMatches[0];

if (best.matchLen > 0 && best.slotIdx >= 0)
{
matchLenOut[reqIdx] = best.matchLen;
matchSlotOut[reqIdx] = best.slotIdx;

uint8_t const* slotMem
= static_cast<uint8_t const*>(slotsMemory) + static_cast<size_t>(best.slotIdx) * stateSize;
SuffixAutomaton const* slot = reinterpret_cast<SuffixAutomaton const*>(slotMem);
slot->getDraftTokens(&draftTokensOut[reqIdx * draftLength], draftLength, best.pos);
}
else
{
matchLenOut[reqIdx] = 0;
matchSlotOut[reqIdx] = -1;
Comment on lines +336 to +337
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In no match branch, it doesn't clear out the draftTokensOut, which could potentially cause the draft tokens to be stale.

In sa_worker.py:

 if sa_manager.enable_global_pool:
            match_len, draft_tokens = sa_manager.extend_global(...)
        else:
            match_len, draft_tokens = sa_manager.extend_ngram(...)
        return draft_tokens

Is it possible that it will directly return the stale draft tokens without any gating?

}
}
}

namespace
{

// Round v up to the nearest power of two, with a floor of 1 (so v <= 1,
// including 0 and negatives, yields 1).
//
// The bit-smear is performed in unsigned arithmetic: the original signed
// version right-shifted negative values (implementation-defined) and
// overflowed on v++ for v near INT_MAX (undefined behavior). Unsigned
// arithmetic wraps deterministically and preserves the practical results.
int nextPowerOf2(int v)
{
    unsigned u = static_cast<unsigned>(v);
    u--;
    u |= u >> 1;
    u |= u >> 2;
    u |= u >> 4;
    u |= u >> 8;
    u |= u >> 16;
    u++;
    int const result = static_cast<int>(u);
    return (result < 1) ? 1 : result;
}

} // anonymous namespace

// Host launcher: extends every request's SA with its accepted tokens, then runs
// the cross-request global search. Both kernels go on the caller's stream, so
// the extend pass is guaranteed to finish before any cross-SA read.
void invokeSuffixAutomatonGlobalSearch(SuffixAutomatonGlobalSearchParams const& params, cudaStream_t stream)
{
    params.checkParams();

    // At most one request per slot can be handled.
    int const numRequests = (params.batchSize <= params.maxSlots) ? params.batchSize : params.maxSlots;
    int const slotCount = params.maxSlots;
    size_t const perSlotBytes = getSuffixAutomatonStateSize(params.maxSeqLen);

    // Phase 1: extend all SAs — one single-thread block per request.
    suffixAutomatonGlobalExtendKernel<<<numRequests, 1, 0, stream>>>(numRequests, params.draftLength, slotCount,
        perSlotBytes, params.slots, params.batchIndices, params.acceptedTokensIn, params.acceptedLensIn);

    // Phase 2: global search + reduce — one block per request. The block size
    // must be a power of 2 for the tree reduction; cap at the hardware limit.
    int blockThreads = nextPowerOf2(slotCount);
    if (blockThreads > 1024)
    {
        blockThreads = 1024;
    }
    size_t const smemBytes = static_cast<size_t>(blockThreads) * sizeof(SlotMatch);

    suffixAutomatonGlobalSearchKernel<<<numRequests, blockThreads, smemBytes, stream>>>(numRequests,
        params.draftLength, params.maxNgramSize, slotCount, perSlotBytes, params.slots, params.batchIndices,
        params.activeSlotMask, params.matchLenOut, params.matchSlotOut, params.draftTokensOut);
}

size_t getSuffixAutomatonStateSize(size_t maxSeqLen)
{
return SuffixAutomaton::getRequiredMemorySize(maxSeqLen);
Expand Down
Loading
Loading