diff --git a/vllm_rbln/attention/backends/flash_attention.py b/vllm_rbln/attention/backends/flash_attention.py
index 689c9a711..c2579b99d 100644
--- a/vllm_rbln/attention/backends/flash_attention.py
+++ b/vllm_rbln/attention/backends/flash_attention.py
@@ -402,7 +402,6 @@ def build(
         if not envs.RBLN_FLASH_CAUSAL_ATTN:
             if input_data.num_prefills:
                 step = steps[0][0]
-                assert input_data.num_prefills == 1
                 prefill_chunk_size = (
                     self.chunked_prefill_size if self.chunked_prefill else
                     1 << (math.ceil(math.log2(input_data.seq_lens[0]))))
@@ -414,13 +413,19 @@ def build(
                     max_seq_len,
                     dtype=torch.float16
                     if self.enforce_eager else torch.float32)
-                causal_mask = 1 - torch.triu(torch.ones(
-                    1, 1, prefill_chunk_size, prefill_chunk_size),
-                                             diagonal=1)
-                if step >= prefill_chunk_size:
-                    chunked_attention_mask[:, :, :, :, :step] = 1
+                valid_len = sum(query_lens)
+                causal_masks = [
+                    torch.tril(torch.ones(query_len, query_len))
+                    for query_len in query_lens
+                ]
+                causal_mask = torch.block_diag(*causal_masks)
+                cur_causal_mask = torch.zeros(1, 1, prefill_chunk_size,
+                                              prefill_chunk_size)
+                cur_causal_mask[:, :, :valid_len, :valid_len] = causal_mask
+                if step > 0:
+                    chunked_attention_mask[:, :, :, :query_lens[0], :step] = 1
                 chunked_attention_mask[:, :, :, :, step:step +
-                                       prefill_chunk_size] = causal_mask
+                                       prefill_chunk_size] = cur_causal_mask
                 attn_masks = chunked_attention_mask
             else:
                 decode_attention_mask = torch.zeros(
diff --git a/vllm_rbln/core/scheduler.py b/vllm_rbln/core/scheduler.py
index e09a52f6b..cdfed8fd5 100644
--- a/vllm_rbln/core/scheduler.py
+++ b/vllm_rbln/core/scheduler.py
@@ -28,6 +28,7 @@
                                  SchedulerSwappedInOutputs, SchedulingBudget)
 from vllm.sequence import SequenceGroup, SequenceStage, SequenceStatus
 
+import vllm_rbln.rbln_envs as envs
 from vllm_rbln.logger import init_logger
 
 logger = init_logger(__name__)
@@ -388,9 +389,12 @@ def _schedule_prefills(
             )
             budget.add_num_seqs(seq_group.request_id, num_new_seqs)
-            # NOTE(RBLN):
-            # For rbln target, we only consider batch size of 1 for prefill.
-            break
+            if not enable_chunking or envs.RBLN_FLASH_CAUSAL_ATTN:
+                # NOTE(RBLN):
+                # For rbln target, we only consider batch size of 1 for
+                # prefill. With chunked prefill, multiple requests can be
+                # scheduled in a single batch.
+                break
 
         logger.debug("waiting_queue -> len=%s", len(waiting_queue))
 
         # Queue requests that couldn't be scheduled.
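For reference, the mask construction added above packs the causal masks of all prefill requests scheduled in one chunk into a single block-diagonal mask, so a token can only attend to earlier tokens of its own request. The following standalone PyTorch sketch reproduces that logic with made-up values for query_lens and prefill_chunk_size (these numbers are illustrative, not taken from the patch):

import torch

# Illustrative values only: two prefill requests of length 3 and 2 packed
# into one chunk of size 8, with no earlier chunk (step == 0).
query_lens = [3, 2]
prefill_chunk_size = 8

valid_len = sum(query_lens)
# One lower-triangular (causal) mask per request, combined block-diagonally
# so requests sharing the chunk cannot attend to each other.
causal_masks = [
    torch.tril(torch.ones(query_len, query_len)) for query_len in query_lens
]
causal_mask = torch.block_diag(*causal_masks)

cur_causal_mask = torch.zeros(1, 1, prefill_chunk_size, prefill_chunk_size)
cur_causal_mask[:, :, :valid_len, :valid_len] = causal_mask

print(cur_causal_mask[0, 0, :valid_len, :valid_len])
# tensor([[1., 0., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [0., 0., 0., 1., 0.],
#         [0., 0., 0., 1., 1.]])

The first 3x3 block is the causal mask of request 0, the trailing 2x2 block belongs to request 1, and the off-diagonal blocks stay zero.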
diff --git a/vllm_rbln/worker/model_runner.py b/vllm_rbln/worker/model_runner.py
index 86e65633d..a0b4a0e51 100644
--- a/vllm_rbln/worker/model_runner.py
+++ b/vllm_rbln/worker/model_runner.py
@@ -202,11 +202,8 @@ def _prepare_prompt(
         num_blocks_per_ve = num_blocks // \
             self.runner.parallel_config.pipeline_parallel_size
         ve_offset = num_blocks_per_ve * virtual_engine
 
-        assert (
-            len(seq_group_metadata_list) == 1), f"seq_group_metadata_list: \
-            len({len(seq_group_metadata_list)}) - {seq_group_metadata_list}"
-        for seq_group_metadata in seq_group_metadata_list:
+        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
             assert seq_group_metadata.is_prompt
             seq_ids = list(seq_group_metadata.seq_data.keys())
             assert len(seq_ids) == 1
@@ -231,7 +228,7 @@ def _prepare_prompt(
             data.num_prefill_tokens += len(tokens)
             data.query_lens.append(len(tokens))
             data.seq_lens.append(seq_len)
-            for i, pos in enumerate(data.input_positions[0]):
+            for pos in data.input_positions[i]:
                 block_number = block_table[pos // block_size]
                 block_offset = pos % block_size
                 data.slot_mapping.append(block_number)
@@ -240,6 +237,14 @@ def _prepare_prompt(
 
         max_seq_len = max(data.seq_lens)
         assert max_seq_len > 0
+        data.input_tokens = [[
+            token for tokens in data.input_tokens for token in tokens
+        ]]
+        data.input_positions = [[
+            position for positions in data.input_positions
+            for position in positions
+        ]]
+
         dummy = num_blocks
         # make_tensor_with_pad takes List[List[]] as input
         # To make it work, input_block_ids is expanded
@@ -248,11 +253,11 @@ def _prepare_prompt(
                                                pad=dummy,
                                                dtype=torch.long,
                                                device=self.device)
-        # input_block_ids gets back in here.
-        input_block_ids = input_block_ids.flatten().tolist()
-        input_block_ids = torch.tensor(input_block_ids,
-                                       dtype=torch.long,
-                                       device=self.device)
+        if data.num_prefills == 1:
+            input_block_ids = input_block_ids.flatten().tolist()
+            input_block_ids = torch.tensor(input_block_ids,
+                                           dtype=torch.long,
+                                           device=self.device)
         prefill_size = (self.chunked_prefill_size
                         if self.chunked_prefill else
                         1 << (math.ceil(math.log2(max_seq_len))))
@@ -612,13 +617,12 @@ def execute_model(
 
         assert model_input.attn_metadata is not None
         token_indices = None
+        num_prefills = model_input.attn_metadata.num_prefills
         if get_pp_group().is_last_rank:
-            num_prefills = model_input.attn_metadata.num_prefills
             selected_token_indices = \
                 model_input.sampling_metadata.selected_token_indices
             len_token_indices = len(selected_token_indices)
             if num_prefills > 0:
-                assert len_token_indices == 0 or len_token_indices == 1
                 num_prefill_tokens = \
                     model_input.attn_metadata.num_prefill_tokens
                 token_indices = torch.tensor(
@@ -637,7 +641,8 @@ def execute_model(
                 input_ids=model_input.input_tokens,
                 positions=model_input.input_positions,
                 intermediate_tensors=intermediate_tensors,
-                selected_token_indices=token_indices,
+                selected_token_indices=selected_token_indices
+                if num_prefills > 0 else token_indices,
                 **execute_model_kwargs,
             )
 
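Similarly, _prepare_prompt now flattens the per-request token and position lists into one packed row before padding. A minimal standalone sketch of that flattening, using hypothetical token ids and positions rather than values from the patch:

# Illustrative values only: token ids and positions of two prompt sequences
# that get packed into a single prefill batch row.
input_tokens = [[11, 12, 13], [21, 22]]
input_positions = [[0, 1, 2], [0, 1]]

# Same flattening as in _prepare_prompt: nested per-request lists become one
# packed row, so the padding and tensor-building code below keeps seeing a
# single batch entry even when several prefills are scheduled together.
input_tokens = [[token for tokens in input_tokens for token in tokens]]
input_positions = [[
    position for positions in input_positions for position in positions
]]

print(input_tokens)     # [[11, 12, 13, 21, 22]]
print(input_positions)  # [[0, 1, 2, 0, 1]]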