Skip to content

Commit c45f2fa

Browse files
feat: compile lm_head using rbln backend (#96) (#99)
* feat: compile lm_head using rbln backend (#96) * feat: optimize compute_logits compilation with TP and multimodal support --------- Co-authored-by: pei0033 <parkeunik@naver.com> Co-authored-by: pei0033 <eunik.park@squeezebits.com>
1 parent d7bc737 commit c45f2fa

1 file changed

Lines changed: 34 additions & 2 deletions

File tree

vllm_rbln/v1/worker/rbln_model_runner.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
get_pp_group, get_tp_group, prepare_communication_buffer_for_model)
3535
from vllm.forward_context import (DPMetadata, get_forward_context,
3636
set_forward_context)
37+
from vllm.model_executor import SamplingMetadata
3738
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
3839
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
3940
from vllm.sampling_params import SamplingType
@@ -750,6 +751,14 @@ def get_dp_padding(self,
750751
)
751752
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
752753

754+
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: Optional[SamplingMetadata] = None,
) -> torch.Tensor:
    """Project hidden states to vocabulary logits.

    Thin wrapper around the model's own ``compute_logits`` so the call
    site can be wrapped (e.g. by ``torch.compile``) without touching the
    model object itself.

    Args:
        hidden_states: Hidden states to project into logit space.
        sampling_metadata: Optional sampling metadata forwarded verbatim
            to the underlying model.

    Returns:
        The logits tensor produced by ``self.model.compute_logits``.
    """
    logits = self.model.compute_logits(hidden_states, sampling_metadata)
    return logits
761+
753762
@torch.inference_mode()
754763
def execute_model(
755764
self,
@@ -920,8 +929,15 @@ def execute_model(
920929
all_gather_group=get_tp_group())
921930
logits = None
922931
else:
923-
sample_hidden_states = hidden_states[logits_indices]
924-
logits = self.model.compute_logits(sample_hidden_states, None)
932+
if is_prefills[0]: # prefill
933+
sample_hidden_states = hidden_states[logits_indices]
934+
logits = self.compute_logits(sample_hidden_states, None)
935+
else: # decode
936+
logits = self.compute_logits(hidden_states, None)
937+
logits = logits[logits_indices]
938+
logits = self.logits_processor._gather_logits(logits)
939+
logits = logits.view(-1, logits.size(-1))
940+
925941
if broadcast_pp_output:
926942
model_output_broadcast_data = ({
927943
"logits": logits.contiguous(),
@@ -1215,6 +1231,13 @@ def load_model(self) -> None:
12151231
self.model_config.get_num_layers(self.parallel_config),
12161232
)
12171233

1234+
# get logits processor from model
1235+
if self.model_config.is_multimodal_model and hasattr(
1236+
self.model.get_language_model(), "logits_processor"):
1237+
self.logits_processor = self.model.get_language_model(
1238+
).logits_processor
1239+
else:
1240+
self.logits_processor = self.model.logits_processor
12181241
# if self.lora_config:
12191242
# self.model = self.load_lora_model(
12201243
# self.model,
@@ -1250,6 +1273,15 @@ def load_model(self) -> None:
12501273

12511274
self.compile_context = CompileContext(use_weight_sharing=True)
12521275
self.model_executable = self._compile_model(self.model)
1276+
self.compute_logits = torch.compile(
1277+
self.compute_logits,
1278+
backend="rbln",
1279+
options={
1280+
"compile_context": self.compile_context,
1281+
"tensor_parallel_size": envs.RBLN_TP_SIZE,
1282+
},
1283+
dynamic=False,
1284+
)
12531285

12541286
def save_tensorized_model(
12551287
self,

0 commit comments

Comments (0)