|
34 | 34 | get_pp_group, get_tp_group, prepare_communication_buffer_for_model) |
35 | 35 | from vllm.forward_context import (DPMetadata, get_forward_context, |
36 | 36 | set_forward_context) |
| 37 | +from vllm.model_executor import SamplingMetadata |
37 | 38 | from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding |
38 | 39 | from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader |
39 | 40 | from vllm.sampling_params import SamplingType |
@@ -750,6 +751,14 @@ def get_dp_padding(self, |
750 | 751 | ) |
751 | 752 | return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding |
752 | 753 |
|
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    sampling_metadata: Optional["SamplingMetadata"] = None,
) -> torch.Tensor:
    """Project hidden states to vocabulary logits.

    Thin delegation wrapper around ``self.model.compute_logits``.  Keeping
    it as a separate bound method gives the loader a stable call site that
    can be swapped out later (e.g. re-bound to a ``torch.compile``-wrapped
    version) without touching the underlying model object.

    Args:
        hidden_states: Hidden-state tensor to project to logits.
        sampling_metadata: Forwarded verbatim to the model; may be ``None``.

    Returns:
        The logits tensor produced by the wrapped model.
    """
    model_logits = self.model.compute_logits(hidden_states, sampling_metadata)
    return model_logits
753 | 762 | @torch.inference_mode() |
754 | 763 | def execute_model( |
755 | 764 | self, |
@@ -920,8 +929,15 @@ def execute_model( |
920 | 929 | all_gather_group=get_tp_group()) |
921 | 930 | logits = None |
922 | 931 | else: |
923 | | - sample_hidden_states = hidden_states[logits_indices] |
924 | | - logits = self.model.compute_logits(sample_hidden_states, None) |
| 932 | + if is_prefills[0]: # prefill |
| 933 | + sample_hidden_states = hidden_states[logits_indices] |
| 934 | + logits = self.compute_logits(sample_hidden_states, None) |
| 935 | + else: # decode |
| 936 | + logits = self.compute_logits(hidden_states, None) |
| 937 | + logits = logits[logits_indices] |
| 938 | + logits = self.logits_processor._gather_logits(logits) |
| 939 | + logits = logits.view(-1, logits.size(-1)) |
| 940 | + |
925 | 941 | if broadcast_pp_output: |
926 | 942 | model_output_broadcast_data = ({ |
927 | 943 | "logits": logits.contiguous(), |
@@ -1215,6 +1231,13 @@ def load_model(self) -> None: |
1215 | 1231 | self.model_config.get_num_layers(self.parallel_config), |
1216 | 1232 | ) |
1217 | 1233 |
|
| 1234 | + # get logits processor from model |
| 1235 | + if self.model_config.is_multimodal_model and hasattr( |
| 1236 | + self.model.get_language_model(), "logits_processor"): |
| 1237 | + self.logits_processor = self.model.get_language_model( |
| 1238 | + ).logits_processor |
| 1239 | + else: |
| 1240 | + self.logits_processor = self.model.logits_processor |
1218 | 1241 | # if self.lora_config: |
1219 | 1242 | # self.model = self.load_lora_model( |
1220 | 1243 | # self.model, |
@@ -1250,6 +1273,15 @@ def load_model(self) -> None: |
1250 | 1273 |
|
1251 | 1274 | self.compile_context = CompileContext(use_weight_sharing=True) |
1252 | 1275 | self.model_executable = self._compile_model(self.model) |
| 1276 | + self.compute_logits = torch.compile( |
| 1277 | + self.compute_logits, |
| 1278 | + backend="rbln", |
| 1279 | + options={ |
| 1280 | + "compile_context": self.compile_context, |
| 1281 | + "tensor_parallel_size": envs.RBLN_TP_SIZE, |
| 1282 | + }, |
| 1283 | + dynamic=False, |
| 1284 | + ) |
1253 | 1285 |
|
1254 | 1286 | def save_tensorized_model( |
1255 | 1287 | self, |
|
0 commit comments