Commit 0c76830

Support for flash attention (#7)

* Support for flash attention

Signed-off-by: Jinseok Lee <jindol21@rebellions.ai>

1 parent 669da6a

5 files changed: 265 additions & 29 deletions

(new file; name not shown in this view)

Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# ruff: noqa: E501

from vllm import LLM, SamplingParams

prompts = [
    """
Rebellions, SK Telecom, and DOCOMO Innovations Partner to Accelerate Next-Gen AI Infrastructure
News
Apr 24, 2025
[Seoul, April 24, 2025] Rebellions today announced the signing of a strategic MOU with SK Telecom and DOCOMO Innovations (DII), a subsidiary of Japan’s leading telecom provider NTT DOCOMO. The agreement lays the foundation for joint development and validation of cutting-edge AI acceleration technologies.

Under this collaboration, the three companies will focus on evaluating Rebellions’ ATOM-based NPU servers within SK Telecom’s NPU farm. Building on the success of this initiative, the partners plan to expand validation efforts to include a broader range of Rebellions’ product portfolio.

Rebellions will bring its high-performance AI chips, proven system reliability, and software optimization expertise to the table. SK Telecom will provide its NPU farm and AI infrastructure capabilities, while DII will conduct technical evaluations and help facilitate discussions on potential paths to commercialization. This partnership is expected to combine the core strengths of each company to help accelerate the growth of the global AI hardware ecosystem.

Rebellions recently established a Japanese subsidiary, further accelerating its global expansion. This agreement marks a significant step in the company’s ambition to become a key player in the global AI infrastructure market.

“Rebellions is committed to proving the real-world stability and performance of our technology in live infrastructure environments,” said Jinwook Oh, CTO of Rebellions. “Collaboration between a hardware provider, an infrastructure leader, and actual users is crucial – and together with SK Telecom and DII, we’re building meaningful progress toward the future of AI.”

Yoshikazu Akinaga, CEO of DOCOMO Innovations, added: “At DOCOMO Innovations, we are dedicated to driving forward innovation in practical AI solutions by working closely with global technology leaders. This partnership—bringing together Rebellions’ advanced semiconductor technology and SK Telecom’s infrastructure capabilities—will allow us to explore and assess the potential of scalable, sustainable AI systems, while maintaining a technology-agnostic approach to ensure optimal solutions for future applications.”

Sangmin Lee, Vice President of Growth Business Development Office at SK Telecom, commented: “SK Telecom delivers world-class, AI-optimized cloud services. This collaboration provides an opportunity to demonstrate our NPU cloud technology, which integrates a wide range of AI data center solutions. We are committed to contributing to the success of next-generation AI infrastructure services—powered not just by GPUs, but also by NPUs.”
""", """
Rebellions Partners on Strategic Collaboration Initiative to Advance Global AI Data Center Ecosystem
News
Mar 04, 2025
[Seoul, March 4, 2025] Rebellions, a pioneering AI chip company, today announced a strategic partnership with Penguin Solutions and SK Telecom at the Mobile World Congress (MWC) 2025 in Barcelona, Spain, taking a significant step toward building a global AI data center ecosystem.

The collaboration aims to establish technical foundations and business capabilities for large-scale data center operations. By combining the unique strengths and experiences of the Rebellions, Penguin Solutions and SK Telecom, the companies will pursue strategic development around AI inference and software stack delivery in the AI data center sector.

Within the planned collaboration, Rebellions brings its portfolio of energy-efficient AI accelerators optimized for Generative AI workloads. Penguin Solutions, a premiere AI infrastructure expert with more than 85,000 deployed GPUs under management, brings deep AI infrastructure expertise to the partnership. SK Telecom, actively accelerating its AI infrastructure business with key software elements and major investments in AI infra companies around the world, completes the strategic alliance.

The three companies will collaborate on multiple fronts, including:

Developing AI infrastructure solutions and creating testing environments for enterprise clients
Joint development of AI data center management solutions by integrating Rebellions’ AI accelerators while supporting both GPU and NPU environments
Leveraging each party’s technical expertise for software development specialized in AI data center infrastructure
“Since ‘DeepSeek’, efficient operations have emerged as a key concept of an AI business, making energy efficiency and cost of ownership critical evaluation criteria for customers,” said Sunghyun Park, CEO of Rebellions. “This partnership represents a crucial first step in establishing an efficient AI data center ecosystem by bringing together companies with diverse technological expertise.”

Mark Seamans, Penguin Solutions’ Vice President of Global Marketing, stated “With this partnership, our deep expertise in HPC and AI cluster management will now extend beyond GPU infrastructure to NPU environments. We’re committed to providing state-of-the-art AI infrastructure that meets the diverse needs of customers, solves for complexity, and accelerates business outcomes in the rapidly-growing global AI market.”

Through this strategic partnership, Rebellions, Penguin Solutions, and SK Telecom aim to identify and expand new business opportunities in the global AI data center market, positioning themselves at the forefront of AI infrastructure innovation.
"""
]

qwen3_0_6_model_id = "Qwen/Qwen3-0.6B"
qwen3_1_7_model_id = "Qwen/Qwen3-1.7B"
qwen3_4_model_id = "Qwen/Qwen3-4B"
qwen3_8_model_id = "Qwen/Qwen3-8B"
qwen3_30_moe_model_id = "Qwen/Qwen3-30B-A3B"
qwen1_5_moe_model_id = "Qwen/Qwen1.5-MoE-A2.7B"

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
llm = LLM(
    model=qwen3_4_model_id,
    max_model_len=16 * 128,
    block_size=128,
    enable_chunked_prefill=True,
    max_num_batched_tokens=128,
    max_num_seqs=5,
)

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
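The engine arguments in the script above are interrelated: `max_model_len` is written as 16 partitions of `block_size` tokens, and chunked prefill consumes at most `max_num_batched_tokens` per step. A quick arithmetic sanity check of those relationships (plain Python, no vLLM required; the derived names are illustrative, not vLLM API):

```python
# Engine arguments as used in the example script above.
block_size = 128
max_model_len = 16 * block_size   # 2048 tokens, i.e. 16 KV-cache blocks
max_num_batched_tokens = 128      # chunked-prefill step size

# Derived quantities (illustrative names only).
num_partitions = max_model_len // block_size
prefill_steps_for_full_context = max_model_len // max_num_batched_tokens

print(num_partitions, prefill_steps_for_full_context)  # -> 16 16
```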

vllm_rbln/attention/backends/flash_attention.py

Lines changed: 14 additions & 8 deletions

@@ -204,7 +204,10 @@ def get_kv_cache_shape(
         kv_cache_shape = [B, H, 1, S, D]
         query_shape = [1, H, G, L, D]
         """
-        return (2, num_blocks, num_kv_heads, 1, block_size, head_size)
+        # For partition skip, we need a dummy block slot.
+        no_dummy_slots = 1
+        return (2, num_blocks + no_dummy_slots, num_kv_heads, 1, block_size,
+                head_size)

     @staticmethod
     def swap_blocks(
@@ -235,10 +238,8 @@ def __init__(self, input_builder: ModelInputForRebelBuilder) -> None:
         self.chunked_prefill = input_builder.chunked_prefill
         self.chunked_prefill_size = input_builder.chunked_prefill_size
         self.input_builder = input_builder
-        # model max sequence length (cache_config.num_cpu_blocks)
-        self.max_seq_len = 128 * 1024
-        # flash attention partition size (cache_config.block_size)
-        self.partition_len = 1024
+
+        self.partition_len = input_builder.block_size

     def prepare(self):
         self.input_data = self.input_builder.input_data
@@ -262,8 +263,13 @@ def build(
         steps = [[input_positions[0]]
                  for input_positions in input_data.input_positions]
         seq_idx = torch.tensor(steps, dtype=torch.int32)
-        max_seq_len = self.max_seq_len
         partition_len = self.partition_len
+        # The number of blocks (a HW constraint) and max_model_len (a model
+        # constraint) each bound the sequence length; the smaller bound is
+        # selected as max_seq_len.
+        block_length = self.input_builder.runner.cache_config.num_gpu_blocks * \
+            partition_len
+        max_seq_len = min(self.input_builder.max_model_len, block_length)
         num_partition = max_seq_len // partition_len

         batch_size = 1 if input_data.num_prefills else len(steps)
@@ -298,7 +304,7 @@ def build(
             1,
             1,
             prefill_chunk_size,
-            self.max_seq_len,
+            max_seq_len,
             dtype=torch.float32)
         causal_mask = 1 - torch.triu(torch.ones(1, 1, prefill_chunk_size,
                                                 prefill_chunk_size),
@@ -313,7 +319,7 @@ def build(
             1,
             1,
             1,
-            self.max_seq_len,
+            max_seq_len,
             dtype=torch.float32)
         for batch_index, batch_step in enumerate(steps):
             decode_attention_mask[batch_index, :, :, :, :batch_step[0] +
vllm_rbln/worker/model_runner.py

Lines changed: 35 additions & 17 deletions

@@ -165,6 +165,8 @@ def __init__(self,
         self.sliding_window = self.runner.sliding_window
         self.block_size = self.runner.cache_config.block_size
         self.device = self.runner.device
+        self.max_model_len = self.runner.scheduler_config.max_model_len
+
         if self.runner.attn_backend is not None:
             # spec decode (e.g. Medusa) does not have atten backend
             attn_backend = self.runner.attn_backend
@@ -190,7 +192,7 @@ def _prepare_prompt(
         seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0
-        input_block_ids: List[int] = []
+        list_input_block_ids: List[List[int]] = []

         block_size = self.runner.block_size
         assert (
@@ -212,8 +214,9 @@ def _prepare_prompt(

             assert seq_group_metadata.block_tables is not None
             block_table = seq_group_metadata.block_tables[seq_id]
-            assert len(block_table) == 1
-            input_block_ids.append(block_table[0])
+            assert len(block_table) == math.ceil(seq_data.get_len() /
+                                                 block_size)
+            list_input_block_ids.append(block_table)
             data.input_tokens.append(tokens)
             data.input_positions.append(list(range(computed_len, seq_len)))
             data.num_prefills += 1
@@ -229,6 +232,21 @@ def _prepare_prompt(
         max_seq_len = max(data.seq_lens)
         assert max_seq_len > 0

+        num_partition = self.max_model_len // block_size
+        dummy = self.runner.cache_config.num_gpu_blocks
+        # make_tensor_with_pad takes List[List[int]] as input, so the
+        # per-sequence block tables are padded to num_partition entries
+        # with the dummy block id.
+        input_block_ids = make_tensor_with_pad(list_input_block_ids,
+                                               max_len=num_partition,
+                                               pad=dummy,
+                                               dtype=torch.long,
+                                               device=self.device)
+        # Flatten back into the 1-D tensor the model expects.
+        input_block_ids = input_block_ids.flatten().tolist()
+        input_block_ids = torch.tensor(input_block_ids,
+                                       dtype=torch.long,
+                                       device=self.device)
+
         prefill_size = (self.chunked_prefill_size if self.chunked_prefill else
                         1 << (math.ceil(math.log2(max_seq_len))))
         input_tokens = make_tensor_with_pad(data.input_tokens,
@@ -241,9 +259,6 @@ def _prepare_prompt(
                                             pad=0,
                                             dtype=torch.long,
                                             device=self.device)
-        input_block_ids = torch.tensor(input_block_ids,
-                                       dtype=torch.long,
-                                       device=self.device)

         logger.info("[RBLN] model input builder, prepare_prompt")
         logger.info("\tpadded input_tokens = %s", input_tokens)
@@ -260,7 +275,7 @@ def _prepare_decode(
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         assert len(seq_group_metadata_list) > 0

-        arr_input_block_ids: List[int] = []
+        list_input_block_ids: List[List[int]] = []
         block_size = self.block_size
         for seq_group_metadata in seq_group_metadata_list:
             assert not seq_group_metadata.is_prompt
@@ -275,10 +290,8 @@ def _prepare_decode(
             assert seq_group_metadata.block_tables is not None
             block_table = seq_group_metadata.block_tables[seq_id]
             assert len(block_table) >= 1
-            for i in range(len(block_table)):
-                assert block_table[i] != self.max_num_seqs
-                arr_input_block_ids.append(block_table[i])

+            list_input_block_ids.append(block_table)
             data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
             data.input_tokens.append([generation_token])
             data.input_positions.append([token_position])
@@ -291,10 +304,18 @@ def _prepare_decode(
             data.slot_mapping.append(block_offset)

         # batch padding
-        batch_padding_szie = self.max_num_seqs - len(data.input_tokens)
-        data.input_tokens.extend([[0]] * batch_padding_szie)
-        data.input_positions.extend([[0]] * batch_padding_szie)
-        arr_input_block_ids.extend([self.max_num_seqs] * batch_padding_szie)
+        dummy = self.runner.cache_config.num_gpu_blocks
+        batch_padding_size = self.max_num_seqs - len(data.input_tokens)
+        data.input_tokens.extend([[0]] * batch_padding_size)
+        data.input_positions.extend([[0]] * batch_padding_size)
+        list_input_block_ids.extend([[dummy]] * batch_padding_size)
+
+        num_partition = self.max_model_len // block_size
+        input_block_ids = make_tensor_with_pad(list_input_block_ids,
+                                               max_len=num_partition,
+                                               pad=dummy,
+                                               dtype=torch.long,
+                                               device=self.device)

         input_tokens = make_tensor_with_pad(data.input_tokens,
                                             max_len=1,
@@ -306,9 +327,6 @@ def _prepare_decode(
                                             pad=0,
                                             dtype=torch.long,
                                             device=self.device)
-        input_block_ids = torch.tensor(arr_input_block_ids,
-                                       dtype=torch.long,
-                                       device=self.device)

         logger.info("[RBLN] model input builder, prepare_decode")
         logger.info("\tpadded input_tokens = %s", data.input_tokens)
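The padding above relies on vLLM's `make_tensor_with_pad` to turn ragged per-sequence block tables into a fixed `[batch, num_partition]` shape, with the dummy block id filling unused slots. A pure-Python sketch of that padding step (no torch; the helper name is illustrative):

```python
from typing import List

def pad_block_tables(block_tables: List[List[int]], num_partition: int,
                     dummy: int) -> List[List[int]]:
    """Pad each ragged block table to num_partition entries with the
    dummy block id, mirroring the make_tensor_with_pad call above."""
    return [table + [dummy] * (num_partition - len(table))
            for table in block_tables]

# Two sequences with ragged block tables, padded to 4 partitions;
# the dummy id (num_gpu_blocks) maps to the extra KV-cache slot.
print(pad_block_tables([[0, 1], [2]], num_partition=4, dummy=16))
# -> [[0, 1, 16, 16], [2, 16, 16, 16]]
```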

vllm_rbln/worker/utils.py

Lines changed: 112 additions & 0 deletions

@@ -0,0 +1,112 @@
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RBLN utility functions."""

import math
from typing import Optional


def get_maximum_num_blocks(
    config,
    tensor_parallel_size: int,
    kvcache_block_size: int,
    nbits_per_param: Optional[int] = None,
    n_model_params: Optional[int] = None,
    kernel_size: Optional[int] = None,
    buffer: Optional[int] = None,
    num_runtimes: int = 2,
) -> int:
    # We are finding the max_n_blocks (x) that satisfies:
    #
    #   available_dram - kernel_size - buffer
    #     - num_layers * 2 * tensor_parallel_size
    #       * align_2MB(
    #           x
    #           * block_size
    #           * align_64(head_dim)
    #           * math.ceil(num_key_value_heads / tensor_parallel_size)
    #           * 2
    #       ) > 0
    #
    # This inequality can be rewritten as:
    #
    #   a - c * align_2MB(b * x) > 0
    # where
    #   a = available_dram - kernel_size - buffer
    #   b = block_size
    #       * align_64(head_dim)
    #       * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
    #   c = num_layers * 2 * tensor_parallel_size
    #
    # Dividing both sides by c gives k > align_2MB(b * x), where k = a / c,
    # and solving for the largest integer x yields:
    #
    #   x = floor(2**21 / b * floor((k - 1) / 2**21))

    def align(x: int, nbytes: int) -> int:
        return int(math.ceil(x / nbytes) * nbytes)

    def align_2MB(x: int) -> int:
        return align(x, 2**21)

    num_layers = config.hf_config.num_hidden_layers
    head_dim = config.hf_config.head_dim
    vocab_size = config.hf_config.vocab_size
    hidden_size = config.hf_config.hidden_size
    num_key_value_heads = config.hf_config.num_key_value_heads

    # TODO(jongho): Update if target npu is REBEL.
    ATOM_DRAM_NBYTES = 16 * 2**30
    ATOM_SYS_DRAM_NBYTES = 288 * 2**20
    available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES -
                                             ATOM_SYS_DRAM_NBYTES)

    if kernel_size is None:
        if n_model_params is None:
            raise ValueError("`n_model_params` should be specified "
                             "to estimate the kernel memory.")
        # Get estimated kernel size (approximated)
        lm_heads_params = align(vocab_size, 64) * hidden_size
        lm_heads_nbytes = (align_2MB(
            lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) *
                           tensor_parallel_size)
        params = n_model_params - lm_heads_params
        layer_nbytes = (align_2MB(params * nbits_per_param // 8 / num_layers /
                                  tensor_parallel_size) * num_layers *
                        tensor_parallel_size)
        kernel_size = layer_nbytes + lm_heads_nbytes
    elif n_model_params is not None:
        raise ValueError(
            "Both `n_model_params` and `kernel_size` cannot be specified.")

    available_dram -= kernel_size

    if buffer is None:
        # TODO: Accurate buffer estimation
        buffer_per_runtime_per_core = 2**28  # 256MB per runtime
        # 1 for prefill, 1 for decode
        buffer_per_core = buffer_per_runtime_per_core * num_runtimes
        buffer = buffer_per_core * tensor_parallel_size
    available_dram -= buffer

    b = kvcache_block_size * align(head_dim, 64) * math.ceil(
        num_key_value_heads / tensor_parallel_size) * 2
    c = num_layers * 2 * tensor_parallel_size
    k = available_dram / c
    max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))

    return max_n_blocks
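The closed form at the end of `get_maximum_num_blocks` can be checked with a standalone sketch that keeps only the final arithmetic. The numbers below are made up for illustration and do not describe a real ATOM configuration; the interpretation of the trailing factor of 2 as the fp16 byte width is an assumption:

```python
import math

def align(x: int, nbytes: int) -> int:
    # Round x up to the next multiple of nbytes.
    return int(math.ceil(x / nbytes) * nbytes)

def max_kv_blocks(available_dram: int, num_layers: int, tp: int,
                  block_size: int, head_dim: int, num_kv_heads: int) -> int:
    """Solve a - c * align_2MB(b * x) > 0 for the largest integer x."""
    # Bytes per block per layer shard (final *2: presumed fp16 byte width).
    b = block_size * align(head_dim, 64) * math.ceil(num_kv_heads / tp) * 2
    # One K and one V tensor per layer, replicated across TP ranks.
    c = num_layers * 2 * tp
    k = available_dram / c
    return math.floor(2**21 / b * math.floor((k - 1) / 2**21))

# Hypothetical 8 GiB budget, 32 layers, TP=1, 1024-token blocks,
# head_dim=128, 8 KV heads:
print(max_kv_blocks(2**33, 32, 1, 1024, 128, 8))  # -> 63
```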
