Commit d18b3ef

v0.1.2-alpha.1 release, featuring MTP. (#16)
Release TileRT v0.1.2-alpha.1 with initial support for Multi-Token Prediction (MTP). With mtp=3, decoding reaches up to 590 tokens/s on synthetic workloads and ~440 tokens/s on real generation tasks.
1 parent 20a862c commit d18b3ef

14 files changed: +1592 / -715 lines

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
@@ -36,6 +36,6 @@ jobs:
       - name: Install lint dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install --no-cache-dir -r requirements-ci.txt
+          pip install --no-cache-dir -r requirements-dev.txt
       - name: Run all linting checks
         run: ./scripts/lint.sh

README.md

Lines changed: 105 additions & 17 deletions
@@ -6,24 +6,37 @@
     <a href="https://huggingface.co/Tile-AI/DeepSeek-V3.2-Exp-TileRT"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-1E90FF"></a>
   </p>
   <p>
-    <a href="#python-package-installation"><b>Installation</b></a> |
-    <a href="#getting-started"><b>Getting Started</b></a>
+    <a href="#overview"><b>Overview</b></a> ·
+    <a href="#running-the-generation-example"><b>Generation</b></a> ·
+    <a href="#running-the-generation-example-with-multi-token-prediction-mtp"><b>MTP Generation</b></a> ·
+    <a href="#installation"><b>Installation</b></a> ·
+    <a href="#news"><b>News</b></a>
   </p>
 </div>
 
-## News
+______________________________________________________________________
 
-- **\[2025-12-23\]** **[v0.1.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.1)** — Achieved ~35% reduction in end-to-end token generation latency on a single node with 8× NVIDIA B200. See our latest benchmarks for detailed measurements.
+<a id="news"></a>
 
-- **\[2025-11-20\]** 🚀 **[v0.1.0-alpha.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.0-alpha.1)** — Initial release of TileRT for DeepSeek-V3.2-Exp, designed for **ultra-low-latency** inference. Available on [PyPI](https://pypi.org/project/tilert) and [HuggingFace](https://huggingface.co/Tile-AI/DeepSeek-V3.2-Exp-TileRT).
+## 📰 News
+
+- 🔥 **2026-01-26 · [v0.1.2-alpha.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.2-alpha.1)**. **Multi-Token Prediction (MTP) lands in TileRT.** With mtp=3, we observe decoding rates of up to **590 tokens/s** under synthetic workloads.
+
+- **2025-12-23 · [v0.1.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.1)**. Achieved a further ~**35% reduction** (3–4× speedup over baseline) in end-to-end token generation latency on a single node with **8× NVIDIA B200**.
+
+- 🚀 **2025-11-20 · [v0.1.0-alpha.1](https://github.com/tile-ai/TileRT/releases/tag/v0.1.0-alpha.1)**. Initial public release for **DeepSeek-V3.2-Exp**, targeting **ultra-low-latency** inference. Available on [PyPI](https://pypi.org/project/tilert) and [HuggingFace](https://huggingface.co/Tile-AI/DeepSeek-V3.2-Exp-TileRT).
+
+______________________________________________________________________
+
+<a id="overview"></a>
 
 ## TileRT: Pushing LLM Latency to the Limit
 
 TileRT is an experimental project exploring core compiler techniques for serving large language models (LLMs) in **ultra-low-latency** scenarios. Its goal is to push the latency limits of LLMs without compromising model size or quality—for example, enabling models with hundreds of billions of parameters to achieve millisecond-level **time per output token (TPOT)**.
 
 <p align="center">
   <img src="assets/generate.gif" alt="TileRT Benchmark"><br>
-  Figure 1. Sequence generation with TileRT.
+  Figure 1. Sequence generation with TileRT, now enhanced with Multi-Token Prediction (MTP) to accelerate inference.
 </p>
 
 We evaluated TileRT’s preliminary performance using the [**DeepSeek-V3.2-Exp**](https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp) model (without lossy optimizations such as quantization or distillation) with a batch size of 1 on 8× NVIDIA B200 GPUs. As shown in the benchmark below, TileRT demonstrates substantial improvements over existing inference systems.
@@ -39,6 +52,8 @@ To achieve this, TileRT introduces a **tile-level runtime engine**. Leveraging a
 
 The project is actively evolving, and the underlying compiler techniques will be gradually shared with the community as they are integrated into **TileLang** and **TileScale**.
 
+______________________________________________________________________
+
 ## Installation
 
 - [Prerequisites](#prerequisites)
@@ -145,39 +160,112 @@ docker run --gpus all -it \
     tilert:v0.1.0
 ```
 
-Once inside the container, you can run the following Python script:
+Once inside the container, run the following Python script to perform text generation:
 
 ```python
 from tilert.models.deepseek_v3_2.dsa_show_hands import ShowHandsGenerator
 
 generator: ShowHandsGenerator = ShowHandsGenerator(
     max_new_tokens=1000,
     model_weights_dir=MODEL_WEIGHTS_DIR,
+    with_mtp=False,  # Disable MTP
 )
 generator.from_pretrained()
 
-prompt = """Tell me three jokes:
-
-1. A dad joke,
-2. A programmer joke,
-3. A joke that only makes sense if you've ever tried to train a large language model.
-Keep each joke under 15 words.
-"""
+prompt = (
+    "Tell me three jokes:\n\n"
+    "1. A dad joke,\n"
+    "2. A programmer joke,\n"
+    "3. A joke that only makes sense if you've ever tried "
+    "to train a large language model.\n"
+    "Keep each joke under 15 words."
+)
 
 print("Prompt:", prompt)
 print("Completion:")
-completion: generator.generate(prompt)
+completion = generator.generate(prompt)
 ```
 
-For instance, using the above prompt, TileRT might generate:
+For example, TileRT may generate:
+
+<details>
+<summary><b>Sample output (click to expand)</b></summary>
 
 ```text
 1. I'm afraid for the calendar. Its days are numbered.
 2. There are only 10 kinds of people: those who understand binary and those who don't.
 3. My model's loss is low, but its answers are still nonsense. Overfitting.
 ```
 
-This example gives you a quick idea of the type of output you can expect from the precompiled model.
+</details>
+
+This example demonstrates basic single-step autoregressive generation using the precompiled model.
+
+### Running the Generation Example with Multi-Token Prediction (MTP)
+
+> \[!IMPORTANT\]
+> **Weights update required for MTP.** Multi-Token Prediction (MTP) introduces additional **MTP heads** in the model weights.
+> If you were using TileRT **before v0.1.1**, please make sure you download the **latest weights** from Hugging Face.
+> Older weights do not include the required MTP heads and will fail to run when MTP is enabled.
+
+TileRT also supports Multi-Token Prediction (MTP), which allows the model to generate multiple tokens per forward pass and reduces sequential decoding depth.
+
+To better illustrate MTP behavior, we use a longer prompt that encourages extended generation:
+
+```python
+from tilert.models.deepseek_v3_2.dsa_show_hands import ShowHandsGenerator
+
+generator: ShowHandsGenerator = ShowHandsGenerator(
+    max_new_tokens=1000,
+    model_weights_dir=MODEL_WEIGHTS_DIR,
+    with_mtp=True,  # Enable MTP
+)
+generator.from_pretrained()
+prompt = "Tell me 10 jokes, keep them all under 100 words."
+
+print("Prompt:", prompt)
+print("Completion:")
+completion = generator.generate(prompt)
+```
+
+When MTP is enabled, TileRT may report statistics similar to the following during generation:
+
+```text
+Accepted length: mean=2.77, min=1, max=4
+```
+
+This indicates that, on average, multiple tokens are accepted per decoding step under MTP.
+
+<details>
+<summary><b>Sample output (click to expand)</b></summary>
+
+```text
+Of course! Here are 10 short jokes for you.
+
+1. I told my wife she was drawing her eyebrows too high. She looked surprised.
+
+2. I invented a new word: Plagiarism.
+
+3. Why don't scientists trust atoms? Because they make up everything.
+
+4. I'm reading a book on anti-gravity. It's impossible to put down.
+
+5. What's the best thing about Switzerland? I don't know, but the flag is a big plus.
+
+6. I told my computer I needed a break, and now it won't stop sending me vacation ads.
+
+7. Why did the scarecrow win an award? He was outstanding in his field.
+
+8. What do you call a fake noodle? An impasta.
+
+9. I told my suitcase there's no vacation, and now it has a lot of baggage.
+
+10. Why don't skeletons fight each other? They don't have the guts.
+```
+
+</details>
+
+This example highlights how MTP enables TileRT to efficiently generate longer outputs by accepting multiple tokens per decoding step, while preserving the same Python API interface.
 
 For more details, please refer to the [generation script](https://github.com/tile-ai/TileRT/blob/main/python/generate.py).
 
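For readers skimming the diff, the `Accepted length` statistics shown above can be reproduced with plain Python. This is a minimal sketch over hypothetical per-step accepted-token counts, not TileRT's internal implementation:

```python
# Sketch: summarizing MTP accepted-token counts per decoding step.
# `accepted_counts` is hypothetical sample data, not TileRT output.
accepted_counts = [3, 2, 4, 1, 3, 4, 2, 3]

mean_accepted = sum(accepted_counts) / len(accepted_counts)
print(
    f"Accepted length: mean={mean_accepted:.2f}, "
    f"min={min(accepted_counts)}, max={max(accepted_counts)}"
)
```

A mean above 1.0 means the MTP heads are, on average, contributing extra accepted tokens per forward pass.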

assets/generate.gif

-1.19 MB

python/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,8 @@ def _load_library(filename: str) -> Any:
     lib_path = Path(__file__).parent / filename
 
     try:
-        return ctypes.CDLL(str(lib_path))
+        torch.ops.load_library(str(lib_path))
+        return lib_path
     except Exception as e:
         raise RuntimeError(f"Failed to load library from {lib_path}") from e
 

python/generate.py

Lines changed: 88 additions & 11 deletions
@@ -1,6 +1,9 @@
 """Text generation script for TileRT."""
 
 from argparse import ArgumentParser
+from typing import cast
+
+import numpy as np
 
 from tilert.models.deepseek_v3_2.dsa_show_hands import ShowHandsGenerator
 
@@ -16,7 +19,16 @@ def parse_args():  # type: ignore
     parser.add_argument("--max-new-tokens", type=int, default=4000, help="Max tokens to generate")
     parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
     parser.add_argument("--interactive", action="store_true")
-    parser.add_argument("--fp8", action="store_true")
+    parser.add_argument(
+        "--with-mtp",
+        action="store_true",
+        help="Enable MTP (Multi-Token Prediction) for speculative decoding",
+    )
+    parser.add_argument(
+        "--use-random-weights",
+        action="store_true",
+        help="Use random weights instead of pretrained (for testing MTP without real weights)",
+    )
     return parser.parse_args()
 
 
@@ -25,22 +37,31 @@ def parse_args():  # type: ignore
     usage:
     execute below command under tilert root directory:
 
+    # Standard generation with pretrained weights:
     python python/generate.py --model-weights-dir "xxxx" 2>&1 | tee test.log
+
+    # MTP generation with random weights (for testing):
+    python python/generate.py --model-weights-dir "xxxx" --with-mtp \
+        --use-random-weights 2>&1 | tee test.log
+
+    # MTP generation with pretrained weights (when available):
+    python python/generate.py --model-weights-dir "xxxx" --with-mtp 2>&1 | tee test.log
     """
     args = parse_args()
 
     generator: ShowHandsGenerator = ShowHandsGenerator(
         max_new_tokens=args.max_new_tokens,
         temperature=args.temperature,
         model_weights_dir=args.model_weights_dir,
-        enable_fp8_ops=args.fp8,
+        with_mtp=args.with_mtp,
     )
 
-    # uncomment to use random weights
-    # generator.init_random_weights()
-
-    # use pretrained weights
-    generator.from_pretrained()
+    if args.use_random_weights:
+        print("Initializing with random weights...")
+        generator.init_random_weights()
+    else:
+        print("Loading pretrained weights...")
+        generator.from_pretrained()
 
     # simple memoryless interactive mode
     if args.interactive:
@@ -53,14 +74,70 @@ def parse_args():  # type: ignore
     else:
         # This prompt is to test the model’s ability to follow instructions
        # (in terms of quantity, type, and length) while keeping it fun.
+        print("==== Performance ====")
         prompt = "Tell me 10 jokes, keep them all under 100 words."
-
         print("Prompt:", prompt)
-        print("Completion:")
-        completion: str = generator.generate(prompt)  # type: ignore[has-type]
+        all_times = []
+        all_accepted = []
+        for _iter in range(20):
+            if _iter % 5 == 0:
+                print(f"Executing iter {_iter}...")
+            results, time_list, accepted_counts = cast(
+                tuple[str, list[float], list[int]],
+                generator.generate(prompt, False),  # type: ignore[has-type]
+            )
+            all_times.append(time_list)
+            all_accepted.append(accepted_counts)
+
+        if args.with_mtp:
+            for token_num in range(100, 300, 100):
+                times_to_token_num = []
+                for time_list, accepted_list in zip(all_times, all_accepted):
+                    if len(time_list) > 5 and len(accepted_list) > 5:
+                        times = time_list[5:]
+                        accepted = accepted_list[5:]
+                        cumsum_tokens = np.cumsum(accepted)
+                        cumsum_times = np.cumsum(times)
+                        # Find index where we reach token_num tokens
+                        idx = np.searchsorted(cumsum_tokens, token_num)
+                        if idx < len(cumsum_times):
+                            times_to_token_num.append(cumsum_times[idx])
+                if times_to_token_num:
+                    mean_total_time = np.mean(times_to_token_num)
+                    mean_time = mean_total_time / token_num
+                    speed = 1 / mean_time
+                    out_str = (
+                        f"**Perf@{token_num}: {speed:.3f} tokens/s & "
+                        f"{(mean_time * 1000):.3f} ms**"
+                    )
+                    print(out_str)
+
+            # Print accepted tokens statistics
+            flat_accepted = [a for accepted_list in all_accepted for a in accepted_list]
+            if flat_accepted:
+                avg_accepted = sum(flat_accepted) / len(flat_accepted)
+                min_accepted = min(flat_accepted)
+                max_accepted = max(flat_accepted)
+                print(
+                    f"**Accepted length: mean={avg_accepted:.2f}, "
+                    f"min={min_accepted}, max={max_accepted}**"
+                )
+        else:
+            all_times_np = np.array(all_times)
+            for token_num in range(100, 300, 100):
+                mean_time = np.mean(all_times_np[..., 5:token_num])
+                speed = 1 / mean_time
+                out_str = (
+                    f"**Perf@{token_num}: {speed:.3f} tokens/s & {(mean_time * 1000):.3f} ms**"
+                )
+                print(out_str)
+        print(results)
 
     # This prompt is used to test long sequence generation
     prompt = "Hi, can you tell me a very long story, with roughly 3000 words?"
     print("Prompt:", prompt)
     print("Completion:")
-    completion = generator.generate(prompt)  # type: ignore[has-type]
+    completion, _, _ = generator.generate(prompt)  # type: ignore[has-type]
+
+    print("Cleaning up...")
+    generator.cleanup()
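The `Perf@{token_num}` logic added above (cumulative accepted tokens vs. cumulative step times) can be sketched in isolation. This pure-Python rendition mirrors what the committed code does with `np.cumsum` and `np.searchsorted`; the function name and the synthetic timing data are illustrative assumptions, not TileRT API:

```python
from bisect import bisect_left
from itertools import accumulate

def time_to_n_tokens(times, accepted, n, warmup=5):
    """Seconds to reach n generated tokens, skipping warmup steps."""
    cum_tokens = list(accumulate(accepted[warmup:]))  # tokens produced so far
    cum_times = list(accumulate(times[warmup:]))      # seconds elapsed so far
    idx = bisect_left(cum_tokens, n)  # first step where cum_tokens >= n
    if idx >= len(cum_times):
        return None  # run too short to reach n tokens
    return cum_times[idx]

# Synthetic data (an assumption, not a measurement): 60 decoding steps,
# 10 ms per step, 2 tokens accepted per step.
times = [0.010] * 60
accepted = [2] * 60

total = time_to_n_tokens(times, accepted, 100)
print(f"Perf@100: {100 / total:.1f} tokens/s")
```

Skipping the first five steps excludes warmup noise, so the reported tokens/s reflects steady-state decoding throughput.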

python/models/base.py

Lines changed: 5 additions & 3 deletions
@@ -9,6 +9,7 @@
 
 from tilert import logger
 from tilert.models.deepseek_config import get_rank, get_world_size
+from tilert.models.deepseek_v3_2.params import BaseParams
 from tilert.models.preprocess import WeightLoader
 from tilert.utils import get_profile_log_tensor
 
@@ -52,9 +53,10 @@ def __init__(
 
         self.flag_enable_tilert = False
 
-        if compute_kernel_type not in ["bf16", "fp8"]:
+        if compute_kernel_type not in ["bf16", "fp8", "fp8mma"]:
             raise ValueError(
-                f"Invalid compute kernel type: {compute_kernel_type}, must be one of bf16, fp8."
+                f"Invalid compute kernel type: {compute_kernel_type}, \
+                    must be one of bf16, fp8, fp8mma."
             )
         self.compute_kernel_type = compute_kernel_type
 
@@ -215,7 +217,7 @@ def tilert_forward(self, *args: Any, **kwargs: Any) -> Any:  # noqa: U100
         raise NotImplementedError("Tilert forward not implemented")
 
     @abstractmethod
-    def to_tilert_weights(self, *args: Any, **kwargs: Any) -> None:
+    def to_tilert_weights(self, *args: Any, **kwargs: Any) -> BaseParams | None:
         """Convert weights to tilert.
 
         Args:
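The extended kernel-type check above can be sketched standalone. Note that the committed f-string uses a backslash line continuation, which embeds the continuation line's leading spaces in the error message; the sketch below (an illustration, not the committed code) uses implicit string concatenation to avoid that:

```python
# Sketch of the compute-kernel-type validation extended in this commit
# to accept "fp8mma"; the real check lives in python/models/base.py.
VALID_KERNEL_TYPES = ("bf16", "fp8", "fp8mma")

def check_kernel_type(compute_kernel_type: str) -> str:
    """Validate a kernel type string, returning it unchanged if valid."""
    if compute_kernel_type not in VALID_KERNEL_TYPES:
        raise ValueError(
            f"Invalid compute kernel type: {compute_kernel_type}, "
            f"must be one of {', '.join(VALID_KERNEL_TYPES)}."
        )
    return compute_kernel_type

print(check_kernel_type("fp8mma"))
```

Keeping the valid set in one tuple means the error message and the membership test can never drift apart.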
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+"""DeepSeek v3.2 model package."""
