Skip to content

Commit e75d23a

Browse files
other: Add warmup runs when running with RBLNSampler() (#37)
* add warm up for sampler * refactor * [skip ci] change default value * [skip ci] update
1 parent b2531e3 commit e75d23a

4 files changed

Lines changed: 219 additions & 25 deletions

File tree

vllm_rbln/rbln_envs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
if TYPE_CHECKING:
2121
RBLN_COMPILE_MODEL: bool = True
2222
RBLN_TP_SIZE: int = 1
23+
RBLN_SAMPLER: bool = False
24+
RBLN_ENABLE_WARM_UP: bool = False
2325

2426
# extended environments
2527
environment_variables = {
@@ -31,6 +33,14 @@
3133
# TP Size for RSD.
3234
"RBLN_TP_SIZE":
3335
lambda: int(os.environ.get("TP_SIZE", 1)),
36+
# Use customized sampler
37+
"RBLN_SAMPLER":
38+
(lambda: os.environ.get("VLLM_RBLN_SAMPLER", "False").lower() in
39+
("true", "1")),
40+
# Enable warmup
41+
"RBLN_ENABLE_WARM_UP":
42+
(lambda: os.environ.get("VLLM_RBLN_ENABLE_WARM_UP", "False").lower() in
43+
("true", "1")),
3444
}
3545

3646

vllm_rbln/v1/sample/sampler.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def _rbln_top_p_sample_only(
8989
sorted_probs = torch.softmax(sorted_logits, dim=-1)
9090
return sorted_probs
9191

92+
@torch.compiler.disable
9293
def apply_penalties(
9394
self,
9495
logits: torch.Tensor,
@@ -105,3 +106,83 @@ def apply_penalties(
105106
sampling_metadata.output_token_ids,
106107
)
107108
return logits
109+
110+
111+
def _build_warm_up_configs():
    """Enumerate sampler warm-up configurations.

    Produces one config per (penalties on/off) x (greedy / top-p / top-k /
    top-p + top-k) combination — eight in total — so every sampling code
    path of the compiled sampler can be traced ahead of serving.
    """
    # Penalty values shared by every "penalty_*" config.
    penalty_fields = {
        "frequency_penalties": 0.1,
        "presence_penalties": 0.1,
        "repetition_penalties": 1.0,
    }
    # (name suffix, extra sampling params).
    # NOTE(review): top_k is a float (1.0) although top-k is conceptually an
    # integer cutoff — kept as-is to preserve the original values; confirm
    # whether an int `1` was intended.
    variants = [
        ("greedy", {}),
        ("topp", {"top_p": 0.9}),
        ("topk", {"top_k": 1.0}),
        ("topp_topk", {"top_p": 0.9, "top_k": 1.0}),
    ]
    configs = []
    for with_penalty in (False, True):
        prefix = "penalty" if with_penalty else "no_penalty"
        for suffix, extra in variants:
            greedy = suffix == "greedy"
            cfg = {
                "name": f"{prefix}_{suffix}",
                "no_penalties": not with_penalty,
            }
            if with_penalty:
                cfg.update(penalty_fields)
            cfg["all_greedy"] = greedy
            cfg["all_random"] = not greedy
            cfg.update(extra)
            # Greedy runs pin temperature to 0.0; random runs use 0.5.
            cfg["temperature"] = 0.0 if greedy else 0.5
            configs.append(cfg)
    return configs


# Sampling configurations used to warm up / pre-compile the RBLN sampler.
WARM_UP_CONFIGS = _build_warm_up_configs()

vllm_rbln/v1/worker/optimum_model_runner.py

Lines changed: 117 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
# http://www.apache.org/licenses/LICENSE-2.0
88

9-
import os
109
# Unless required by applicable law or agreed to in writing, software
1110
# distributed under the License is distributed on an "AS IS" BASIS,
1211
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -29,28 +28,23 @@
2928
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
3029
KVCacheSpec)
3130
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
31+
from vllm.v1.sample.sampler import Sampler
3232
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
3333
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
3434

35+
import vllm_rbln.rbln_envs as envs
3536
from vllm_rbln.logger import init_logger
3637
from vllm_rbln.model_executor.model_loader.rbln_model_loader import (
3738
get_optimum_model)
3839
from vllm_rbln.model_executor.models.optimum import (ModelInputForRBLN,
3940
RBLNOptimumDictTableMixin)
40-
from vllm_rbln.v1.sample.sampler import Sampler
41+
from vllm_rbln.v1.sample.sampler import WARM_UP_CONFIGS
4142
from vllm_rbln.v1.sample.sampler import Sampler as RBLNSampler
4243
from vllm_rbln.v1.worker.multimodal import RBLNOptimumMultiModalKwargs
4344

4445
logger = init_logger(__name__)
4546

4647

47-
def _use_rbln_sampler() -> bool:
48-
"""Check if RBLN sampler should be used based on environment variable."""
49-
TRUTHY_VALUES = frozenset({"1", "true", "yes", "on"})
50-
return os.environ.get("VLLM_RBLN_SAMPLER",
51-
"").strip().lower() in TRUTHY_VALUES
52-
53-
5448
class RBLNOptimumModelRunner(LoRAModelRunnerMixin):
5549

5650
def __init__(self, vllm_config: VllmConfig, device: torch.device):
@@ -101,21 +95,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
10195
self.mm_registry = MULTIMODAL_REGISTRY
10296
self.uses_mrope = model_config.uses_mrope
10397

104-
# encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
105-
# model_config=model_config,
106-
# scheduler_config=scheduler_config,
107-
# mm_registry=self.mm_registry,
108-
# )
109-
# self.max_num_encoder_input_tokens = encoder_compute_budget
110-
# self.encoder_cache_size = encoder_cache_size
111-
11298
# Sampler
113-
use_rbln_sampler = _use_rbln_sampler()
114-
logger.info("Using RBLN sampler: %s", use_rbln_sampler)
115-
116-
sampler = RBLNSampler() if use_rbln_sampler else Sampler()
117-
118-
if use_rbln_sampler:
99+
self.use_rbln_sampler = envs.RBLN_SAMPLER
100+
logger.info("Using RBLN sampler: %s", self.use_rbln_sampler)
101+
sampler = RBLNSampler() if self.use_rbln_sampler else Sampler()
102+
if self.use_rbln_sampler:
119103
# Use torch.compile for optimized RBLN sampler
120104
sampler = torch.compile(sampler, dynamic=False, fullgraph=False)
121105

@@ -622,3 +606,113 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
622606

623607
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
624608
pass
609+
610+
def dummy_sampler_run(self):
    """Warm up the torch.compile'd RBLN sampler.

    For every configuration in ``WARM_UP_CONFIGS`` and every batch size up
    to ``input_batch.max_num_reqs``, fabricates dummy requests and sampling
    metadata, then invokes the sampler on empty logits so compilation is
    triggered ahead of real traffic. No-op when the stock vLLM sampler is
    in use.
    """
    if not self.use_rbln_sampler:
        return

    batch = self.input_batch

    def _fill_sampling_tensors(**params):
        # Temperature is mandatory for every config; the rest are optional.
        batch.temperature_cpu_tensor.fill_(params["temperature"])
        batch.temperature.fill_(params["temperature"])

        # Parameter name -> (CPU tensor, device tensor) to fill when set.
        tensor_pairs = {
            "top_p": (batch.top_p_cpu_tensor, batch.top_p),
            "top_k": (batch.top_k_cpu_tensor, batch.top_k),
            "min_p": (batch.min_p_cpu_tensor, batch.min_p),
            "frequency_penalties": (batch.frequency_penalties_cpu_tensor,
                                    batch.frequency_penalties),
            "presence_penalties": (batch.presence_penalties_cpu_tensor,
                                   batch.presence_penalties),
            "repetition_penalties": (batch.repetition_penalties_cpu_tensor,
                                     batch.repetition_penalties),
        }
        for key, (cpu_tensor, dev_tensor) in tensor_pairs.items():
            value = params.get(key)
            if value is not None:
                cpu_tensor.fill_(value)
                dev_tensor.fill_(value)

    def _register_requests(cfg, num_reqs):
        # Fabricate num_reqs dummy request ids and tag them with the
        # request sets that match the config's sampling parameters.
        for idx in range(num_reqs):
            rid = f"{cfg['name']}_req_{idx}"
            batch._req_ids.append(rid)
            batch.req_id_to_index[rid] = idx

            if cfg["all_greedy"]:
                batch.greedy_reqs.add(rid)
            elif cfg["all_random"]:
                batch.random_reqs.add(rid)

            for param, req_set in (
                ("top_p", batch.top_p_reqs),
                ("top_k", batch.top_k_reqs),
                ("frequency_penalties", batch.frequency_penalties_reqs),
                ("repetition_penalties", batch.repetition_penalties_reqs),
                ("presence_penalties", batch.presence_penalties_reqs),
            ):
                if cfg.get(param) is not None:
                    req_set.add(rid)

    def _reset_requests():
        batch._req_ids.clear()
        batch.req_id_to_index.clear()
        for req_set in (batch.greedy_reqs, batch.random_reqs,
                        batch.top_p_reqs, batch.top_k_reqs,
                        batch.frequency_penalties_reqs,
                        batch.repetition_penalties_reqs,
                        batch.presence_penalties_reqs):
            req_set.clear()

    def _trace_all_batch_sizes(cfg):
        # Compile once per batch size so all shapes are covered.
        for num_reqs in range(1, batch.max_num_reqs + 1):
            _register_requests(cfg, num_reqs)

            metadata = batch._make_sampling_metadata()
            metadata.no_penalties = cfg["no_penalties"]
            metadata.all_greedy = cfg["all_greedy"]
            metadata.all_random = cfg["all_random"]

            # Penalty application indexes into prompt tokens; provide a
            # stub tensor when the dummy batch has none.
            if (not metadata.no_penalties
                    and metadata.prompt_token_ids is None):
                metadata.prompt_token_ids = torch.zeros((num_reqs, 1),
                                                        dtype=torch.long,
                                                        device="cpu")

            logger.info(
                "Running dummy compile with batch_size=%d, vocab_size=%d",
                num_reqs, batch.vocab_size)
            logger.info("Sampling metadata: %s", metadata)

            with torch.inference_mode():
                dummy_logits = torch.empty(num_reqs,
                                           batch.vocab_size,
                                           dtype=torch.float32)
                self.sampler(logits=dummy_logits,
                             sampling_metadata=metadata)

            _reset_requests()

    for cfg in WARM_UP_CONFIGS:
        logger.info("Running dummy sampler config: %s", cfg["name"])

        _fill_sampling_tensors(
            temperature=cfg["temperature"],
            top_p=cfg.get("top_p"),
            top_k=cfg.get("top_k"),
            min_p=cfg.get("min_p"),
            frequency_penalties=cfg.get("frequency_penalties"),
            repetition_penalties=cfg.get("repetition_penalties"),
            presence_penalties=cfg.get("presence_penalties"),
        )

        _trace_all_batch_sizes(cfg)

vllm_rbln/v1/worker/optimum_worker.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from vllm.v1.outputs import ModelRunnerOutput
2929
from vllm.v1.worker.worker_base import WorkerBase
3030

31+
import vllm_rbln.rbln_envs as envs
3132
from vllm_rbln.v1.worker.optimum_model_runner import RBLNOptimumModelRunner
3233

3334
logger = init_logger(__name__)
@@ -113,8 +114,16 @@ def compile_or_warm_up_model(self) -> None:
113114
# Reset the seed to ensure that the random state is not affected by
114115
# the model initialization and profiling.
115116
set_random_seed(self.model_config.seed)
116-
# TODO(eunji): warmup is required?
117-
# self.model_runner.warming_up_model()
117+
118+
if not envs.RBLN_ENABLE_WARM_UP:
119+
logger.info(
120+
"Warm up is disabled. " \
121+
"Set VLLM_RBLN_ENABLE_WARM_UP=1 to enable warm up."
122+
)
123+
return
124+
125+
logger.info("Running dummy warm up.")
126+
self.model_runner.dummy_sampler_run()
118127

119128
def get_model(self) -> nn.Module:
    """Return the model instance held by the underlying model runner."""
    runner = self.model_runner
    return runner.get_model()

0 commit comments

Comments
 (0)