Skip to content

Commit a561198

Browse files
authored
Merge branch 'main' into user/xiy/6kd_kernel
2 parents c1f4c29 + d3536f1 commit a561198

File tree

416 files changed

+17616
-6020
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to locate content that may be hidden.

416 files changed

+17616
-6020
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,6 +1030,7 @@ common-files: &common_files |
10301030
tests/unittest/_torch/speculative/test_kv_cache_reuse.py |
10311031
tests/unittest/_torch/speculative/test_mtp.py |
10321032
tests/unittest/_torch/speculative/test_ngram.py |
1033+
tests/unittest/_torch/speculative/test_sa.py |
10331034
tests/unittest/_torch/speculative/test_save_state.py |
10341035
tests/unittest/_torch/speculative/test_spec_gate.py |
10351036
tests/unittest/_torch/speculative/test_torch_rejection_sampling.py |

LICENSE

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ Original Source: https://github.com/sgl-project/sglang
5656
Copyright contributors to the SGLang project
5757
Licensed under the Apache License 2.0
5858

59+
--------------------------------------------------------------------------------
60+
Suffix Automaton Speculative Decoding
61+
--------------------------------------------------------------------------------
62+
Original Source: https://github.com/basetenlabs/sa_spec
63+
Copyright 2025 Baseten
64+
Licensed under the Apache License 2.0
65+
5966
--------------------------------------------------------------------------------
6067
Text Generation Inference
6168
--------------------------------------------------------------------------------

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ state-of-the-art optimizations to perform inference efficiently on NVIDIA GPUs.<
1111
[![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
1212
[![cuda](https://img.shields.io/badge/cuda-13.1.0-green)](https://developer.nvidia.com/cuda-downloads)
1313
[![torch](https://img.shields.io/badge/torch-2.9.1-green)](https://pytorch.org)
14-
[![version](https://img.shields.io/badge/release-1.3.0rc6-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
14+
[![version](https://img.shields.io/badge/release-1.3.0rc7-green)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/version.py)
1515
[![license](https://img.shields.io/badge/license-Apache%202-blue)](https://github.com/NVIDIA/TensorRT-LLM/blob/main/LICENSE)
1616

1717
[Architecture](https://nvidia.github.io/TensorRT-LLM/developer-guide/overview.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Performance](https://nvidia.github.io/TensorRT-LLM/developer-guide/perf-overview.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](https://nvidia.github.io/TensorRT-LLM/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap)

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,6 +1216,15 @@ class BlockManager
12161216
return mWindowBlockManagers.begin()->first;
12171217
}
12181218

1219+
/// Returns the attention-window size of the last entry in mWindowBlockManagers,
/// or 0 when no window block managers exist.
/// NOTE(review): mWindowBlockManagers appears to be keyed by window size
/// (begin()->first / rbegin()->first are used as window sizes); if it is an
/// ordered map this is the largest configured window — confirm container type.
[[nodiscard]] SizeType32 getLastWindowSize() const
{
    return mWindowBlockManagers.empty() ? 0 : mWindowBlockManagers.rbegin()->first;
}
1227+
12191228
[[nodiscard]] SizeType32 getNumAllocNewBlocks() const
12201229
{
12211230
return sumWindows([](auto const& manager) { return manager.getNumAllocNewBlocks(); });

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2149,10 +2149,13 @@ KVCacheManager::KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer
21492149
// disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case
21502150
, mEnableBlockReuse{mSinkBubbleLength > 0 ? false : enableBlockReuse}
21512151
{
2152+
// When num_layers < len(maxAttentionWindowVec), not all window sizes in the
2153+
// repeating pattern are used. Update mMaxAttentionWindow to the actual
2154+
// maximum window size that has been allocated in the block manager.
2155+
mMaxAttentionWindow = mBlockManager.getLastWindowSize();
2156+
21522157
TLLM_CHECK_WITH_INFO(mSinkBlockTokenLength == 0 && mSinkBubbleLength == 0,
21532158
"[kv cache manager] streamLLM is not supported at the moment");
2154-
TLLM_CHECK_DEBUG(std::find(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end(), mMaxAttentionWindow)
2155-
!= maxAttentionWindowVec.end());
21562159
// The sink tokens are stored in blocks separate from other tokens.
21572160
// If the last block of sink tokens is only partially filled,
21582161
// we fill that block with a "bubble" to reach the number of tokens per block.
Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
* Adapted from Baseten's sa_spec library (Apache-2.0)
18+
* https://github.com/basetenlabs/sa_spec
19+
*/
20+
21+
#pragma once
22+
23+
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>
#include <utility>

#include "saCudaCallable.h"

#include "tensorrt_llm/common/config.h"
31+
32+
TRTLLM_NAMESPACE_BEGIN
33+
34+
namespace kernels::speculative_decoding::suffix_automaton
35+
{
36+
37+
/**
38+
* @brief A fixed-capacity buffer that uses external memory (pointer-based).
39+
*
40+
* This is a view into externally-managed memory. The buffer does not own
41+
* the memory and does not perform any allocation/deallocation.
42+
*
43+
* This design enables:
44+
* - Runtime-configurable capacity (no compile-time template parameter)
45+
* - Trivially copyable (can be memcpy'd between host and GPU)
46+
* - CUDA graph compatible (fixed memory addresses after initialization)
47+
*
48+
* @tparam T Element type (must be trivially copyable)
49+
* @tparam IndexT Index type (default size_t)
50+
*/
51+
template <typename T, typename IndexT = size_t>
52+
struct SABuffer
53+
{
54+
T* mData{nullptr};
55+
size_t mCapacity{0};
56+
57+
T const& at(IndexT, IndexT) const = delete;
58+
T& at(IndexT, IndexT) = delete;
59+
60+
SABuffer() = default;
61+
62+
SA_CUDA_CALLABLE void init(T* data, size_t capacity)
63+
{
64+
mData = data;
65+
mCapacity = capacity;
66+
}
67+
68+
SA_CUDA_CALLABLE T const& at(IndexT row) const
69+
{
70+
assert(static_cast<size_t>(+row) < mCapacity);
71+
return mData[+row];
72+
}
73+
74+
SA_CUDA_CALLABLE T& at(IndexT row)
75+
{
76+
assert(static_cast<size_t>(+row) < mCapacity);
77+
return mData[+row];
78+
}
79+
80+
struct Iterator
81+
{
82+
SABuffer const& buffer;
83+
IndexT index;
84+
85+
SA_CUDA_CALLABLE Iterator(SABuffer const& buffer, IndexT index)
86+
: buffer(buffer)
87+
, index(index)
88+
{
89+
}
90+
91+
SA_CUDA_CALLABLE T const& operator*() const
92+
{
93+
return buffer.at(index);
94+
}
95+
96+
SA_CUDA_CALLABLE Iterator& operator++()
97+
{
98+
index = IndexT(+index + 1);
99+
return *this;
100+
}
101+
102+
SA_CUDA_CALLABLE bool operator==(Iterator const& other) const
103+
{
104+
return index == other.index;
105+
}
106+
107+
SA_CUDA_CALLABLE bool operator!=(Iterator const& other) const
108+
{
109+
return index != other.index;
110+
}
111+
};
112+
113+
SA_CUDA_CALLABLE Iterator begin() const
114+
{
115+
return Iterator(*this, IndexT(0));
116+
}
117+
118+
SA_CUDA_CALLABLE Iterator end() const
119+
{
120+
return Iterator(*this, IndexT(mCapacity));
121+
}
122+
123+
SA_CUDA_CALLABLE size_t size() const
124+
{
125+
return mCapacity;
126+
}
127+
128+
SA_CUDA_CALLABLE size_t capacity() const
129+
{
130+
return mCapacity;
131+
}
132+
133+
void clear()
134+
{
135+
if (mData && mCapacity > 0)
136+
{
137+
memset(static_cast<void*>(mData), 0, mCapacity * sizeof(T));
138+
}
139+
}
140+
141+
SA_CUDA_CALLABLE T* data()
142+
{
143+
return mData;
144+
}
145+
146+
SA_CUDA_CALLABLE T const* data() const
147+
{
148+
return mData;
149+
}
150+
151+
/**
152+
* @brief Relocate the data pointer by a given delta.
153+
*
154+
* Used when copying between host and GPU memory to adjust pointers.
155+
*/
156+
void relocate(ptrdiff_t delta)
157+
{
158+
if (mData)
159+
{
160+
mData = reinterpret_cast<T*>(reinterpret_cast<uint8_t*>(mData) + delta);
161+
}
162+
}
163+
164+
static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable");
165+
};
166+
167+
/**
168+
* @brief A dynamic buffer with runtime-configurable capacity using external memory.
169+
*
170+
* Like SABuffer, but tracks current length separately from capacity.
171+
* Supports push/pop operations up to the capacity limit.
172+
*
173+
* @tparam T Element type (must be trivially copyable)
174+
* @tparam IndexT Index type (default size_t)
175+
*/
176+
template <typename T, typename IndexT = size_t>
177+
struct SADynamicBuffer
178+
{
179+
T* mData{nullptr};
180+
size_t mCapacity{0};
181+
IndexT mLength{0};
182+
183+
SADynamicBuffer() = default;
184+
185+
SA_CUDA_CALLABLE void init(T* data, size_t capacity)
186+
{
187+
mData = data;
188+
mCapacity = capacity;
189+
mLength = IndexT(0);
190+
}
191+
192+
SA_CUDA_CALLABLE void clear()
193+
{
194+
mLength = IndexT(0);
195+
}
196+
197+
SA_CUDA_CALLABLE IndexT size() const
198+
{
199+
return mLength;
200+
}
201+
202+
SA_CUDA_CALLABLE size_t capacity() const
203+
{
204+
return mCapacity;
205+
}
206+
207+
SA_CUDA_CALLABLE bool empty() const
208+
{
209+
return +size() == 0;
210+
}
211+
212+
SA_CUDA_CALLABLE void extend(size_t n)
213+
{
214+
mLength = IndexT(+mLength + n);
215+
assert(static_cast<size_t>(+mLength) <= mCapacity);
216+
}
217+
218+
SA_CUDA_CALLABLE T& pushBack(T const& value)
219+
{
220+
assert(static_cast<size_t>(+mLength) < mCapacity);
221+
222+
T& result = mData[+mLength];
223+
result = value;
224+
mLength = IndexT(+mLength + 1);
225+
return result;
226+
}
227+
228+
SA_CUDA_CALLABLE T& pushBack(T&& value)
229+
{
230+
assert(static_cast<size_t>(+mLength) < mCapacity);
231+
T& result = mData[+mLength];
232+
result = std::move(value);
233+
mLength = IndexT(+mLength + 1);
234+
return result;
235+
}
236+
237+
SA_CUDA_CALLABLE T& popBack()
238+
{
239+
assert(!empty());
240+
T& result = mData[+mLength - 1];
241+
mLength = IndexT(+mLength - 1);
242+
return result;
243+
}
244+
245+
SA_CUDA_CALLABLE T const& at(IndexT row) const
246+
{
247+
assert(row < mLength);
248+
return mData[+row];
249+
}
250+
251+
SA_CUDA_CALLABLE T& at(IndexT row)
252+
{
253+
assert(row < mLength);
254+
return mData[+row];
255+
}
256+
257+
SA_CUDA_CALLABLE T* data()
258+
{
259+
return mData;
260+
}
261+
262+
SA_CUDA_CALLABLE T const* data() const
263+
{
264+
return mData;
265+
}
266+
267+
SA_CUDA_CALLABLE bool hasCapacity() const
268+
{
269+
return static_cast<size_t>(+mLength) < mCapacity;
270+
}
271+
272+
/**
273+
* @brief Relocate the data pointer by a given delta.
274+
*
275+
* Used when copying between host and GPU memory to adjust pointers.
276+
*/
277+
void relocate(ptrdiff_t delta)
278+
{
279+
if (mData)
280+
{
281+
mData = reinterpret_cast<T*>(reinterpret_cast<uint8_t*>(mData) + delta);
282+
}
283+
}
284+
285+
static_assert(std::is_trivially_copyable<T>::value, "T must be trivially copyable");
286+
};
287+
288+
// Compile-time guarantee: both buffer views are trivially copyable, which is
// what makes raw host<->device memcpy of these structs legal.
static_assert(std::is_trivially_copyable<SABuffer<int>>::value, "SABuffer must be trivially copyable");
static_assert(std::is_trivially_copyable<SADynamicBuffer<int>>::value, "SADynamicBuffer must be trivially copyable");
291+
292+
} // namespace kernels::speculative_decoding::suffix_automaton
293+
294+
TRTLLM_NAMESPACE_END
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
* SPDX-License-Identifier: Apache-2.0
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*
17+
* Adapted from Baseten's sa_spec library (Apache-2.0)
18+
* https://github.com/basetenlabs/sa_spec
19+
*/
20+
21+
#pragma once

#ifdef __CUDACC__
#include <cuda_runtime.h>
// Callable from both host and device code; force-inlined on the GPU path.
#define SA_CUDA_CALLABLE __host__ __device__ __forceinline__
#else
#define SA_CUDA_CALLABLE
// Provide a placeholder type for cudaStream_t when not compiling with CUDA.
// NOTE: `defined(cudaStream_t)` only detects a *macro* named cudaStream_t;
// the real cudaStream_t from CUDA's driver_types.h is a typedef, which
// `defined` never sees. We therefore also guard on driver_types.h's include
// guard (__DRIVER_TYPES_H__) so a host-only TU that transitively includes
// CUDA headers does not get the real typedef shadowed by this macro.
#if !defined(cudaStream_t) && !defined(__DRIVER_TYPES_H__)
#define cudaStream_t int
#endif
#endif // __CUDACC__

0 commit comments

Comments
 (0)