vllm_rbln/attention/backends/flash_attention.py (19 changes: 12 additions & 7 deletions)
@@ -402,7 +402,6 @@ def build(
         if not envs.RBLN_FLASH_CAUSAL_ATTN:
             if input_data.num_prefills:
                 step = steps[0][0]
-                assert input_data.num_prefills == 1
                 prefill_chunk_size = (
                     self.chunked_prefill_size if self.chunked_prefill else 1 <<
                     (math.ceil(math.log2(input_data.seq_lens[0]))))
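
Note: the fallback above rounds the prompt length up to the next power of two for the chunk size. A quick sanity check of that expression, with a made-up sequence length:

import math

seq_len = 300  # illustrative prompt length, not a value from this PR
prefill_chunk_size = 1 << math.ceil(math.log2(seq_len))
print(prefill_chunk_size)  # 512, the smallest power of two >= seq_len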
@@ -414,13 +413,19 @@ def build(
                     max_seq_len,
                     dtype=torch.float16
                     if self.enforce_eager else torch.float32)
-                causal_mask = 1 - torch.triu(torch.ones(
-                    1, 1, prefill_chunk_size, prefill_chunk_size),
-                                             diagonal=1)
-                if step >= prefill_chunk_size:
-                    chunked_attention_mask[:, :, :, :, :step] = 1
+                valid_len = sum(query_lens)
+                causal_masks = [
+                    torch.tril(torch.ones(query_len, query_len))
+                    for query_len in query_lens
+                ]
+                causal_mask = torch.block_diag(*causal_masks)
+                cur_causal_mask = torch.zeros(1, 1, prefill_chunk_size,
+                                              prefill_chunk_size)
+                cur_causal_mask[:, :, :valid_len, :valid_len] = causal_mask
+                if step > 0:
+                    chunked_attention_mask[:, :, :, :query_lens[0], :step] = 1
                 chunked_attention_mask[:, :, :, :, step:step +
-                                       prefill_chunk_size] = causal_mask
+                                       prefill_chunk_size] = cur_causal_mask
                 attn_masks = chunked_attention_mask
             else:
                 decode_attention_mask = torch.zeros(
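
The new mask construction packs several prefill requests into one chunk while keeping them isolated from each other. A minimal standalone sketch of the idea, using torch.block_diag with illustrative query lengths (the names mirror the diff; the values are made up):

import torch

query_lens = [2, 3, 1]  # hypothetical per-request prompt lengths
prefill_chunk_size = 8

# One lower-triangular (causal) mask per packed request.
causal_masks = [torch.tril(torch.ones(q, q)) for q in query_lens]

# block_diag places each mask on the diagonal; the off-diagonal blocks
# stay zero, so tokens of one request never attend to another request.
causal_mask = torch.block_diag(*causal_masks)

# Embed the packed mask into the fixed-size chunk; padding stays masked.
valid_len = sum(query_lens)  # 6
cur_causal_mask = torch.zeros(1, 1, prefill_chunk_size, prefill_chunk_size)
cur_causal_mask[:, :, :valid_len, :valid_len] = causal_mask

print(causal_mask)
# tensor([[1., 0., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0., 0.],
#         [0., 0., 1., 0., 0., 0.],
#         [0., 0., 1., 1., 0., 0.],
#         [0., 0., 1., 1., 1., 0.],
#         [0., 0., 0., 0., 0., 1.]])

The zero off-diagonal blocks are what make packed prefill safe: the diff replaces a single triangular mask over the whole chunk with this block-diagonal one, which is why the assert on num_prefills == 1 could be removed.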
vllm_rbln/core/scheduler.py (10 changes: 7 additions & 3 deletions)
@@ -28,6 +28,7 @@
                                     SchedulerSwappedInOutputs, SchedulingBudget)
 from vllm.sequence import SequenceGroup, SequenceStage, SequenceStatus
 
+import vllm_rbln.rbln_envs as envs
 from vllm_rbln.logger import init_logger
 
 logger = init_logger(__name__)
@@ -388,9 +389,12 @@ def _schedule_prefills(
             )
             budget.add_num_seqs(seq_group.request_id, num_new_seqs)
 
-            # NOTE(RBLN):
-            # For rbln target, we only consider batch size of 1 for prefill.
-            break
+            if not enable_chunking or envs.RBLN_FLASH_CAUSAL_ATTN:
+                # NOTE(RBLN):
+                # For the RBLN target, prefill normally runs with a batch
+                # size of 1. When chunked prefill is enabled, multiple
+                # requests can be scheduled in a single batch instead.
+                break
 
         logger.debug("waiting_queue -> len=%s", len(waiting_queue))
         # Queue requests that couldn't be scheduled.
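
For context, a condensed sketch of the new gating in the scheduling loop (the loop body and names are heavily simplified from _schedule_prefills, and flash_causal_attn stands in for envs.RBLN_FLASH_CAUSAL_ATTN):

# Simplified sketch; budget checks and sequence-group details are omitted.
def schedule_prefills(waiting_queue: list, enable_chunking: bool,
                      flash_causal_attn: bool) -> list:
    scheduled = []
    while waiting_queue:
        scheduled.append(waiting_queue.pop(0))
        # Old behavior (kept for the flash-causal-attention path):
        # exactly one prefill request per batch.
        if not enable_chunking or flash_causal_attn:
            break
    return scheduled

With chunked prefill on and RBLN_FLASH_CAUSAL_ATTN off, the loop keeps draining the waiting queue, which is what lets _prepare_prompt below see more than one sequence group.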
vllm_rbln/worker/model_runner.py (31 changes: 18 additions & 13 deletions)
@@ -202,11 +202,8 @@ def _prepare_prompt(
         num_blocks_per_ve = num_blocks // \
             self.runner.parallel_config.pipeline_parallel_size
         ve_offset = num_blocks_per_ve * virtual_engine
-        assert (
-            len(seq_group_metadata_list) == 1), f"seq_group_metadata_list: \
-            len({len(seq_group_metadata_list)}) - {seq_group_metadata_list}"
 
-        for seq_group_metadata in seq_group_metadata_list:
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
             assert len(seq_ids) == 1
@@ -231,7 +228,7 @@ def _prepare_prompt(
             data.num_prefill_tokens += len(tokens)
             data.query_lens.append(len(tokens))
             data.seq_lens.append(seq_len)
-            for i, pos in enumerate(data.input_positions[0]):
+            for pos in data.input_positions[i]:
                 block_number = block_table[pos // block_size]
                 block_offset = pos % block_size
                 data.slot_mapping.append(block_number)
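
As background for the loop above, each token position is translated through the block table of the paged KV cache. A tiny sketch with made-up values (block_size and block_table are illustrative, not taken from this PR):

block_size = 4
block_table = [7, 2, 9]    # logical block index -> physical block number
positions = [0, 1, 5, 10]  # token positions of one sequence

for pos in positions:
    block_number = block_table[pos // block_size]  # which physical block
    block_offset = pos % block_size                # slot within the block
    print(pos, block_number, block_offset)
# 0 -> block 7, offset 0
# 1 -> block 7, offset 1
# 5 -> block 2, offset 1
# 10 -> block 9, offset 2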
@@ -240,6 +237,14 @@ def _prepare_prompt(
         max_seq_len = max(data.seq_lens)
         assert max_seq_len > 0
 
+        data.input_tokens = [[
+            token for tokens in data.input_tokens for token in tokens
+        ]]
+        data.input_positions = [[
+            position for positions in data.input_positions
+            for position in positions
+        ]]
+
         dummy = num_blocks
         # make_tensor_with_pad takes List[List[]] as input
         # To make it work, input_block_ids is expanded
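
A quick illustration of the flattening added above, with made-up tokens for two packed requests; the result stays a one-element List[List[int]], which is the shape make_tensor_with_pad expects:

# Tokens of two prefill requests before flattening.
input_tokens = [[101, 102, 103], [201, 202]]

# Concatenate into one packed sequence, kept inside a length-1 outer list.
input_tokens = [[token for tokens in input_tokens for token in tokens]]
print(input_tokens)  # [[101, 102, 103, 201, 202]]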
@@ ... @@ def _prepare_prompt(
             pad=dummy,
             dtype=torch.long,
             device=self.device)
-        # input_block_ids gets back in here.
-        input_block_ids = input_block_ids.flatten().tolist()
-        input_block_ids = torch.tensor(input_block_ids,
-                                       dtype=torch.long,
-                                       device=self.device)
+        if data.num_prefills == 1:
+            input_block_ids = input_block_ids.flatten().tolist()
+            input_block_ids = torch.tensor(input_block_ids,
+                                           dtype=torch.long,
+                                           device=self.device)
 
         prefill_size = (self.chunked_prefill_size if self.chunked_prefill else
                         1 << (math.ceil(math.log2(max_seq_len))))
@@ -612,13 +617,12 @@ def execute_model(
 
         assert model_input.attn_metadata is not None
         token_indices = None
+        num_prefills = model_input.attn_metadata.num_prefills
         if get_pp_group().is_last_rank:
-            num_prefills = model_input.attn_metadata.num_prefills
             selected_token_indices = \
                 model_input.sampling_metadata.selected_token_indices
             len_token_indices = len(selected_token_indices)
             if num_prefills > 0:
-                assert len_token_indices == 0 or len_token_indices == 1
                 num_prefill_tokens = \
                     model_input.attn_metadata.num_prefill_tokens
                 token_indices = torch.tensor(
@@ ... @@ def execute_model(
             input_ids=model_input.input_tokens,
             positions=model_input.input_positions,
             intermediate_tensors=intermediate_tensors,
-            selected_token_indices=token_indices,
+            selected_token_indices=selected_token_indices
+            if num_prefills > 0 else token_indices,
             **execute_model_kwargs,
         )

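To see why the packed-prefill path passes the sampler's selected_token_indices through instead of the single-prefill token_indices, a small sketch (query lengths are made up; it assumes sampling happens at each request's last prompt token, which is how vLLM's sampling metadata treats prompts):

import torch

query_lens = [3, 2]  # hypothetical prompt lengths of two packed prefills
last_token_indices = torch.cumsum(torch.tensor(query_lens), dim=0) - 1
print(last_token_indices.tolist())  # [2, 4]

With a single prefill the removed assert held (at most one selected index), so a locally built token_indices sufficed; with several packed requests the model call needs one index per request, which is exactly what selected_token_indices already carries.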