Skip to content

Commit 27aa201

Browse files
authored
refactor: update variable names and logging for better code clarity (#105)
* refactor: update variable names and type annotations; improve readability
* refactor: convert info-level logging to debug level
1 parent 32f57ba commit 27aa201

4 files changed

Lines changed: 50 additions & 50 deletions

File tree

vllm_rbln/attention/backends/flash_attention.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -453,15 +453,15 @@ def build(
453453
attn_masks=attn_masks,
454454
kv_caches=None,
455455
)
456-
logger.info("RBLNAttentionMetadata = %s", attn_metadata)
457-
logger.info("\tslot_mapping size = %s", slot_mapping.size())
458-
logger.info("\tblock_tables size = %s", block_tables.size())
456+
logger.debug("RBLNAttentionMetadata = %s", attn_metadata)
457+
logger.debug("\tslot_mapping size = %s", slot_mapping.size())
458+
logger.debug("\tblock_tables size = %s", block_tables.size())
459459
if not envs.RBLN_FLASH_CAUSAL_ATTN and attn_masks is not None:
460-
logger.info("\tattn_masks size = %s", attn_masks.size())
461-
logger.info("\tattn_masks = %s", attn_masks[:, :, :, :, :32])
460+
logger.debug("\tattn_masks size = %s", attn_masks.size())
461+
logger.debug("\tattn_masks = %s", attn_masks[:, :, :, :, :32])
462462
else:
463463
assert attn_masks is None
464-
logger.info("\tseq_lens_tensor size= %s", seq_lens_tensor.size())
464+
logger.debug("\tseq_lens_tensor size= %s", seq_lens_tensor.size())
465465
return attn_metadata
466466

467467

vllm_rbln/v1/attention/backends/flash_attention.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -514,12 +514,12 @@ def build(
514514
kv_caches=None,
515515
)
516516

517-
logger.info("RBLNAttentionMetadata = %s", attn_metadata)
518-
logger.info("\tslot_mapping size = %s", slot_mapping.size())
519-
logger.info("\tblock_tables size = %s", block_tables_tensor.size())
520-
logger.info("\tattn_masks size = %s", attn_masks.size())
521-
logger.info("\tattn_masks = %s", attn_masks[:, :, :, :, :32])
522-
logger.info("\tseq_lens_tensor size= %s", seq_lens_tensor.size())
517+
logger.debug("RBLNAttentionMetadata = %s", attn_metadata)
518+
logger.debug("\tslot_mapping size = %s", slot_mapping.size())
519+
logger.debug("\tblock_tables size = %s", block_tables_tensor.size())
520+
logger.debug("\tattn_masks size = %s", attn_masks.size())
521+
logger.debug("\tattn_masks = %s", attn_masks[:, :, :, :, :32])
522+
logger.debug("\tseq_lens_tensor size= %s", seq_lens_tensor.size())
523523
return attn_metadata
524524

525525
def use_cascade_attention(self, *args, **kwargs) -> bool:

vllm_rbln/v1/worker/rbln_model_runner.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -653,19 +653,19 @@ def _prepare_inputs(
653653
# Hot-Swap lora model
654654
# if self.lora_config:
655655
# self.set_active_loras(self.input_batch, num_scheduled_tokens)
656-
logger.info("num_reqs: %s", num_reqs)
657-
logger.info("token_indices: %s", token_indices)
658-
logger.info("input_batch: %s", vars(self.input_batch))
659-
logger.info(
656+
logger.debug("num_reqs: %s", num_reqs)
657+
logger.debug("token_indices: %s", token_indices)
658+
logger.debug("input_batch: %s", vars(self.input_batch))
659+
logger.debug(
660660
"input_ids: %s",
661661
self.input_ids[:scheduler_output.total_num_scheduled_tokens],
662662
)
663-
logger.info(
663+
logger.debug(
664664
"positions: %s",
665665
self.positions[:scheduler_output.total_num_scheduled_tokens],
666666
)
667-
logger.info("attn_metadata: %s", next(iter(attn_metadata.items())))
668-
logger.info("logits_indices: %s", logits_indices)
667+
logger.debug("attn_metadata: %s", next(iter(attn_metadata.items())))
668+
logger.debug("logits_indices: %s", logits_indices)
669669
return attn_metadata, logits_indices, spec_decode_metadata
670670

671671
def _compile_model(self, model):

vllm_rbln/worker/model_runner.py

Lines changed: 31 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -267,12 +267,12 @@ def _prepare_prompt(
267267
dtype=torch.long,
268268
device=self.device)
269269

270-
logger.info("[RBLN] model input builder, prepare_prompt")
271-
logger.info("\tpadded input_tokens = %s", input_tokens)
272-
logger.info("\tpadded input_positions = %s", input_positions)
273-
logger.info("\tinput_block_ids = %s", input_block_ids)
274-
logger.info("\tseq_lens = %s", data.seq_lens)
275-
logger.info("\tquery_lens = %s", data.query_lens)
270+
logger.debug("[RBLN] model input builder, prepare_prompt")
271+
logger.debug("\tpadded input_tokens = %s", input_tokens)
272+
logger.debug("\tpadded input_positions = %s", input_positions)
273+
logger.debug("\tinput_block_ids = %s", input_block_ids)
274+
logger.debug("\tseq_lens = %s", data.seq_lens)
275+
logger.debug("\tquery_lens = %s", data.query_lens)
276276
return (input_tokens, input_positions, input_block_ids)
277277

278278
def _prepare_decode(
@@ -340,12 +340,12 @@ def _prepare_decode(
340340
dtype=torch.long,
341341
device=self.device)
342342

343-
logger.info("[RBLN] model input builder, prepare_decode")
344-
logger.info("\tpadded input_tokens = %s", data.input_tokens)
345-
logger.info("\tpadded input_positions = %s", data.input_positions)
346-
logger.info("\tinput_block_ids = %s", input_block_ids)
347-
logger.info("\tseq_lens = %s", data.seq_lens)
348-
logger.info("\tquery_lens = %s", data.query_lens)
343+
logger.debug("[RBLN] model input builder, prepare_decode")
344+
logger.debug("\tpadded input_tokens = %s", data.input_tokens)
345+
logger.debug("\tpadded input_positions = %s", data.input_positions)
346+
logger.debug("\tinput_block_ids = %s", input_block_ids)
347+
logger.debug("\tseq_lens = %s", data.seq_lens)
348+
logger.debug("\tquery_lens = %s", data.query_lens)
349349

350350
assert input_tokens.shape[0] == self.max_num_seqs
351351
assert input_positions.shape[0] == self.max_num_seqs
@@ -520,10 +520,10 @@ def model_wrapper(
520520
model_output = model_output[:, selected_token_indices]
521521
logits = self.compute_logits_model.compute_logits(
522522
model_output, None)
523+
return logits
523524
else:
524525
# non last rank create intermediate tensors, bypass it
525-
logits = model_output
526-
return logits
526+
return model_output
527527

528528
if self.model_config.enforce_eager or not envs.RBLN_COMPILE_MODEL:
529529
self.model_executable = model_wrapper
@@ -583,9 +583,9 @@ def prepare_model_input(
583583

584584
is_prompt = seq_group_metadata_list[
585585
0].is_prompt if seq_group_metadata_list else None
586-
logger.info("[RBLN] num_requests = %d", len(seq_group_metadata_list))
587-
logger.info("[RBLN] input_ids = %s", model_input.input_tokens)
588-
logger.info("[RBLN] positions = %s", model_input.input_positions)
586+
logger.debug("[RBLN] num_requests = %d", len(seq_group_metadata_list))
587+
logger.debug("[RBLN] input_ids = %s", model_input.input_tokens)
588+
logger.debug("[RBLN] positions = %s", model_input.input_positions)
589589
return dataclasses.replace(model_input,
590590
sampling_metadata=sampling_metadata,
591591
virtual_engine=virtual_engine,
@@ -594,12 +594,12 @@ def prepare_model_input(
594594
@torch.inference_mode()
595595
def execute_model(
596596
self,
597-
model_input: ModelInputForRebel,
597+
model_input: ModelInputForRebelWithSamplingMetadata,
598598
kv_caches: Optional[List[torch.Tensor]] = None,
599599
intermediate_tensors: Optional[IntermediateTensors] = None,
600600
num_steps: int = 1,
601601
previous_hidden_states: Optional[torch.Tensor] = None,
602-
) -> Optional[SamplerOutput]:
602+
) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
603603
assert kv_caches is not None
604604
if num_steps > 1:
605605
raise ValueError(
@@ -613,6 +613,7 @@ def execute_model(
613613
assert model_input.attn_metadata is not None
614614
token_indices = None
615615
if get_pp_group().is_last_rank:
616+
assert model_input.sampling_metadata is not None
616617
num_prefills = model_input.attn_metadata.num_prefills
617618
selected_token_indices = \
618619
model_input.sampling_metadata.selected_token_indices
@@ -633,30 +634,29 @@ def execute_model(
633634
if model_input.attn_metadata is not None:
634635
model_input.attn_metadata.kv_caches = kv_caches
635636

636-
hidden_states = self.model_executable(
637+
logits_or_intermediate_states = self.model_executable(
637638
input_ids=model_input.input_tokens,
638639
positions=model_input.input_positions,
639640
intermediate_tensors=intermediate_tensors,
640641
selected_token_indices=token_indices,
641642
**execute_model_kwargs,
642643
)
643644

644-
if get_pp_group().is_last_rank:
645-
# Gather logits for TP
646-
logits_processor = self.compute_logits_model.logits_processor
647-
hidden_states = logits_processor._gather_logits(hidden_states)
648-
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
645+
if get_pp_group().is_last_rank:
646+
# Gather logits for TP
647+
logits_processor = self.compute_logits_model.logits_processor
648+
logits = logits_processor._gather_logits(
649+
logits_or_intermediate_states)
650+
logits = logits.view(-1, logits.size(-1))
649651

650-
if not get_pp_group().is_last_rank:
651-
intermediate_states = hidden_states
652+
else:
653+
intermediate_states = logits_or_intermediate_states
652654
assert isinstance(intermediate_states, IntermediateTensors)
653655
return intermediate_states
654656

655657
# Compute the logits. -> moved to model executable
656-
if num_prefills > 0 and len_token_indices != 0:
657-
logits = hidden_states
658-
else:
659-
logits = hidden_states[selected_token_indices]
658+
if not (num_prefills > 0 and len_token_indices != 0):
659+
logits = logits[selected_token_indices]
660660

661661
# Only perform sampling in the driver worker.
662662
if not self.is_driver_worker:

0 commit comments

Comments (0)