Skip to content

Commit 9d049db

Browse files
fix: use sampling during warmup and disable backed_size_oblivious after model compilation (#551)
# Description A couple of changes related to compilation of operations during sampling. - the [`batched_count_greater_than` function](https://github.com/vllm-project/vllm/blob/b8b302cde434df8c9289a2b465406b47ebab1c2d/vllm/v1/sample/ops/logprobs.py#L11) requires compilation and is used to compute logprobs. - we also found that this function will fail to compile in Pytorch 2.8.0 and 2.9.0 when `backed_size_oblivious` is enabled, so this PR disables `backed_size_oblivious` after model compilation. ## Related Issues Fixes #550 --------- Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
1 parent f081f4f commit 9d049db

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

vllm_spyre/v1/worker/spyre_worker.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,17 @@ def __stagger_exit__(*args, **kwargs):
103103
_inside_warmup_mode = False
104104

105105

106+
@contextlib.contextmanager
107+
def use_torch_fx_backed_size_oblivious():
108+
# this setting is required to mark a dimension of size 1 as dynamic
109+
# for pytorch >= 2.7.1 (needed to support batch size 1 for decodes)
110+
# NB: this setting is disabled at the end of this function
111+
from torch.fx.experimental import _config as config
112+
config.backed_size_oblivious = True
113+
yield
114+
config.backed_size_oblivious = False
115+
116+
106117
class SpyreWorker(WorkerBaseV1):
107118
"""A worker class that executes the model on a group of Spyre cores.
108119
"""
@@ -432,12 +443,6 @@ def load_model(self):
432443
logger.info("load model took %.3fs", load_model_total_t)
433444

434445
def _warmup_spyre_dynamic_size(self, special_token_ids):
435-
# this setting is required to mark a dimension of size 1 as dynamic
436-
# for pytorch >= 2.7.1 (needed to support batch size 1 for decodes)
437-
438-
from torch.fx.experimental import _config as config
439-
config.backed_size_oblivious = True
440-
441446
warmup_start_t = time.time()
442447

443448
# satisfy mypy
@@ -491,6 +496,21 @@ def _warmup_spyre_dynamic_size(self, special_token_ids):
491496
# one additional prefill to deploy the compiled program to the device,
492497
# the necessary operations are included in the graph and will be removed
493498
# after this execution
499+
500+
# update sampling_params here to ensure logits processing code is also
501+
# compiled during warmup
502+
deploy_req.sampling_params = SamplingParams(
503+
temperature=1.0,
504+
top_k=10,
505+
top_p=0.9,
506+
min_p=0.9,
507+
presence_penalty=0.5,
508+
frequency_penalty=0.5,
509+
repetition_penalty=1.2,
510+
max_tokens=4,
511+
min_tokens=1,
512+
logprobs=1,
513+
)
494514
scheduler_output = SchedulerOutput(
495515
scheduled_new_reqs=[deploy_req],
496516
scheduled_cached_reqs=CachedRequestData.make_empty(),
@@ -651,6 +671,7 @@ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
651671
num_decode_tokens, warmup_total_t, compile_cache_str)
652672
maybe_override_signals_handler()
653673

674+
@use_torch_fx_backed_size_oblivious()
654675
def _dynamic_warmup(
655676
self,
656677
requests: list[NewRequestData],

0 commit comments

Comments
 (0)