
Commit 807aca6

add: initial work for enabling warmup in v1 engine (#84)

* add: enable warmup in compile_or_warm_up_model

Signed-off-by: Huijong JEONG <huijong.jeong@squeezebits.com>

1 parent a2c280e · commit 807aca6

2 files changed: 133 additions & 4 deletions

vllm_rbln/v1/worker/rbln_model_runner.py

Lines changed: 128 additions & 2 deletions
@@ -37,11 +37,12 @@
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.sampling_params import SamplingType
+from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, LazyLoader, check_use_alibi,
                         is_pin_memory_available, make_tensor_with_pad)
 from vllm.v1.attention.backends.utils import CommonAttentionMetadata
+from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
                                         KVCacheConfig, KVCacheSpec,
                                         SlidingWindowSpec)
@@ -60,7 +61,6 @@
 if TYPE_CHECKING:
     import xgrammar as xgr
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
-    from vllm.v1.core.sched.output import SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")

@@ -253,6 +253,8 @@ def __init__(
         self.max_num_batched_tokens = (
             self.scheduler_config.max_num_batched_tokens)
 
+        self._accumulative_compilation_count = 0
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
@@ -759,6 +761,111 @@ def compute_logits(
 
         return self.model.compute_logits(hidden_states, sampling_metadata)
 
+    @torch.inference_mode()
+    def warmup_model(self) -> None:
+        # compile prefill graph
+        prefill_seq_len = (self.scheduler_config.max_num_batched_tokens
+                           if self.scheduler_config.chunked_prefill_enabled
+                           else self.scheduler_config.max_model_len)
+        dummy_prefill_schedule = SchedulerOutput(
+            scheduled_new_reqs=[
+                NewRequestData(
+                    req_id="dummy_prefill",
+                    prompt_token_ids=list(range(prefill_seq_len)),
+                    mm_hashes=[],
+                    mm_inputs=[],
+                    mm_positions=[],
+                    sampling_params=SamplingParams(temperature=0.0),
+                    block_ids=([0], ),
+                    num_computed_tokens=0,
+                    lora_request=None,
+                )
+            ],
+            scheduled_cached_reqs=[],
+            num_scheduled_tokens={"dummy_prefill": prefill_seq_len},
+            total_num_scheduled_tokens=prefill_seq_len,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[0],
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+            kv_connector_metadata=None)
+        dummy_prefill_cleanup = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=[],
+            num_scheduled_tokens={},
+            total_num_scheduled_tokens=0,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[1],
+            finished_req_ids={
+                "dummy_prefill",
+            },
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+            kv_connector_metadata=None)
+        self.execute_model(dummy_prefill_schedule)
+        self.execute_model(dummy_prefill_cleanup)
+
+        num_prefill_graphs = self._accumulative_compilation_count
+        logger.info("Compiled %d graph(s) for prefill", num_prefill_graphs)
+
+        # compile decode graph
+        decode_max_batch_size = self.scheduler_config.max_num_seqs
+        decode_max_seq_len = self.scheduler_config.max_model_len
+        decode_max_num_blocks = (decode_max_seq_len + self.block_size -
+                                 1) // self.block_size
+        dummy_decode_schedule = SchedulerOutput(
+            scheduled_new_reqs=[
+                NewRequestData(
+                    req_id=f"dummy_decode_{i}",
+                    prompt_token_ids=list(range(decode_max_seq_len - 1)),
+                    mm_hashes=[],
+                    mm_inputs=[],
+                    mm_positions=[],
+                    sampling_params=SamplingParams(temperature=0.0),
+                    block_ids=([0] * decode_max_num_blocks, ),
+                    num_computed_tokens=decode_max_seq_len - 1,
+                    lora_request=None,
+                ) for i in range(decode_max_batch_size)
+            ],
+            scheduled_cached_reqs=[],
+            num_scheduled_tokens={
+                f"dummy_decode_{i}": 1
+                for i in range(decode_max_batch_size)
+            },
+            total_num_scheduled_tokens=decode_max_batch_size,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[0],
+            finished_req_ids=set(),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+            kv_connector_metadata=None)
+        dummy_decode_cleanup = SchedulerOutput(
+            scheduled_new_reqs=[],
+            scheduled_cached_reqs=[],
+            num_scheduled_tokens={},
+            total_num_scheduled_tokens=0,
+            scheduled_spec_decode_tokens={},
+            scheduled_encoder_inputs={},
+            num_common_prefix_blocks=[1],
+            finished_req_ids=set(f"dummy_decode_{i}"
+                                 for i in range(decode_max_batch_size)),
+            free_encoder_input_ids=[],
+            structured_output_request_ids={},
+            grammar_bitmask=None,
+            kv_connector_metadata=None)
+        self.execute_model(dummy_decode_schedule)
+        self.execute_model(dummy_decode_cleanup)
+
+        logger.info("Compiled %d graph(s) for decode",
+                    self._accumulative_compilation_count - num_prefill_graphs)
+
     @torch.inference_mode()
     def execute_model(
         self,
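
The decode warm-up above deliberately compiles the worst-case shape: max_num_seqs dummy requests, each one token short of max_model_len, with a fully populated block table. A self-contained sketch of the ceiling-division block arithmetic it relies on (the concrete numbers below are made-up examples, not the plugin's defaults):

def decode_max_num_blocks(max_model_len: int, block_size: int) -> int:
    # Ceiling division: the number of fixed-size KV-cache blocks needed
    # to hold a sequence of max_model_len tokens.
    return (max_model_len + block_size - 1) // block_size

# Hypothetical example values, for illustration only:
assert decode_max_num_blocks(4096, 16) == 256
assert decode_max_num_blocks(4097, 16) == 257  # one extra token, one extra block
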
@@ -1155,6 +1262,22 @@ def execute_model(
         if has_kv_transfer_group():
             get_kv_transfer_group().clear_connector_metadata()
 
+        compilation_metrics = torch._dynamo.utils.get_compilation_metrics()
+        if len(compilation_metrics) > self._accumulative_compilation_count:
+            new_compilation_metrics = compilation_metrics[
+                self._accumulative_compilation_count:]
+            reasons = ", ".join([
+                cm.recompile_reason or "initial compilation"
+                for cm in new_compilation_metrics
+            ])
+            logger.debug(
+                "graph compilation(s) triggered due to following reason(s): %s",
+                reasons)
+            self._accumulative_compilation_count += len(
+                new_compilation_metrics)
+            logger.debug("accumulative compilation count: %s",
+                         self._accumulative_compilation_count)
+
         return ModelRunnerOutput(
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
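
The recompile tracking added here reads torch._dynamo's process-wide compilation log after every execute_model call. A minimal standalone sketch of the same pattern (torch._dynamo.utils is a private PyTorch API and the CompilationMetrics fields can differ between releases, so the getattr below hedges against that):

import torch

@torch.compile
def square(x):
    return x * x

square(torch.randn(4))  # the first call triggers an initial compilation

for cm in torch._dynamo.utils.get_compilation_metrics():
    # Mirrors the diff above: recompile_reason is None (or absent) for the
    # initial compilation and holds a human-readable reason on recompiles.
    print(getattr(cm, "recompile_reason", None) or "initial compilation")
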
@@ -1633,6 +1756,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)
 
+        self.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
+        self.cache_config.num_cpu_blocks = 0
+
     def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         """
         Generates the KVCacheSpec by parsing the kv cache format from each
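
Publishing kv_cache_config.num_blocks as cache_config.num_gpu_blocks makes the device block budget visible to the rest of the worker. One way to read that budget, sketched with invented numbers rather than values from this repo, is as a cap on concurrently resident full-length sequences:

def max_resident_full_len_seqs(num_gpu_blocks: int, max_model_len: int,
                               block_size: int) -> int:
    # Blocks a single max-length sequence pins, then how many such
    # sequences the device block budget can hold at once.
    blocks_per_seq = (max_model_len + block_size - 1) // block_size
    return num_gpu_blocks // blocks_per_seq

assert max_resident_full_len_seqs(1024, 4096, 16) == 4  # 1024 // 256
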

vllm_rbln/v1/worker/rbln_worker.py

Lines changed: 5 additions & 2 deletions
@@ -212,8 +212,11 @@ def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         self.model_runner.initialize_kv_cache(kv_cache_config)
 
     def compile_or_warm_up_model(self) -> None:
-        logger.warning("model warm-up is not supported on RBLN.")
-        pass
+        if self.model_config.enforce_eager or not envs.RBLN_COMPILE_MODEL:
+            logger.warning("skipping compile_or_warm_up_model")
+            return
+
+        self.model_runner.warmup_model()
 
     def get_model(self) -> nn.Module:
         return self.model_runner.get_model()
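
The new gate skips warm-up when eager mode is forced or when the plugin's envs.RBLN_COMPILE_MODEL flag is off. The actual parsing of that flag lives in the plugin's envs module; a hypothetical equivalent that reads the variable directly from the environment (the real parsing may differ) looks like:

import os

def rbln_compile_model() -> bool:
    # Hypothetical stand-in for envs.RBLN_COMPILE_MODEL: unset defaults to
    # enabled; "0"/"false" (case-insensitive) disable it.
    return os.environ.get("RBLN_COMPILE_MODEL", "1").lower() not in ("0", "false")
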
