|
6 | 6 |
|
7 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 |
8 | 8 |
|
9 | | -import os |
10 | 9 | # Unless required by applicable law or agreed to in writing, software |
11 | 10 | # distributed under the License is distributed on an "AS IS" BASIS, |
12 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
29 | 28 | from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, |
30 | 29 | KVCacheSpec) |
31 | 30 | from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput |
| 31 | +from vllm.v1.sample.sampler import Sampler |
32 | 32 | from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch |
33 | 33 | from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin |
34 | 34 |
|
| 35 | +import vllm_rbln.rbln_envs as envs |
35 | 36 | from vllm_rbln.logger import init_logger |
36 | 37 | from vllm_rbln.model_executor.model_loader.rbln_model_loader import ( |
37 | 38 | get_optimum_model) |
38 | 39 | from vllm_rbln.model_executor.models.optimum import (ModelInputForRBLN, |
39 | 40 | RBLNOptimumDictTableMixin) |
40 | | -from vllm_rbln.v1.sample.sampler import Sampler |
| 41 | +from vllm_rbln.v1.sample.sampler import WARM_UP_CONFIGS |
41 | 42 | from vllm_rbln.v1.sample.sampler import Sampler as RBLNSampler |
42 | 43 | from vllm_rbln.v1.worker.multimodal import RBLNOptimumMultiModalKwargs |
43 | 44 |
|
44 | 45 | logger = init_logger(__name__) |
45 | 46 |
|
46 | 47 |
|
47 | | -def _use_rbln_sampler() -> bool: |
48 | | - """Check if RBLN sampler should be used based on environment variable.""" |
49 | | - TRUTHY_VALUES = frozenset({"1", "true", "yes", "on"}) |
50 | | - return os.environ.get("VLLM_RBLN_SAMPLER", |
51 | | - "").strip().lower() in TRUTHY_VALUES |
52 | | - |
53 | | - |
54 | 48 | class RBLNOptimumModelRunner(LoRAModelRunnerMixin): |
55 | 49 |
|
56 | 50 | def __init__(self, vllm_config: VllmConfig, device: torch.device): |
@@ -101,21 +95,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device): |
101 | 95 | self.mm_registry = MULTIMODAL_REGISTRY |
102 | 96 | self.uses_mrope = model_config.uses_mrope |
103 | 97 |
|
104 | | - # encoder_compute_budget, encoder_cache_size = compute_encoder_budget( |
105 | | - # model_config=model_config, |
106 | | - # scheduler_config=scheduler_config, |
107 | | - # mm_registry=self.mm_registry, |
108 | | - # ) |
109 | | - # self.max_num_encoder_input_tokens = encoder_compute_budget |
110 | | - # self.encoder_cache_size = encoder_cache_size |
111 | | - |
112 | 98 | # Sampler |
113 | | - use_rbln_sampler = _use_rbln_sampler() |
114 | | - logger.info("Using RBLN sampler: %s", use_rbln_sampler) |
115 | | - |
116 | | - sampler = RBLNSampler() if use_rbln_sampler else Sampler() |
117 | | - |
118 | | - if use_rbln_sampler: |
| 99 | + self.use_rbln_sampler = envs.RBLN_SAMPLER |
| 100 | + logger.info("Using RBLN sampler: %s", self.use_rbln_sampler) |
| 101 | + sampler = RBLNSampler() if self.use_rbln_sampler else Sampler() |
| 102 | + if self.use_rbln_sampler: |
119 | 103 | # Use torch.compile for optimized RBLN sampler |
120 | 104 | sampler = torch.compile(sampler, dynamic=False, fullgraph=False) |
121 | 105 |
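
This hunk replaces the inline _use_rbln_sampler() helper with a centralized flag read from vllm_rbln.rbln_envs. The attribute name RBLN_SAMPLER and the VLLM_RBLN_SAMPLER environment variable both appear in the diff; how the module parses the variable is not shown, so the sketch below is an assumption that it keeps the same truthy-string handling as the removed helper:

import os

# Hypothetical sketch of vllm_rbln/rbln_envs.py: parse VLLM_RBLN_SAMPLER
# the way the removed _use_rbln_sampler() helper did, but at module level
# so every caller reads one shared flag.
_TRUTHY_VALUES = frozenset({"1", "true", "yes", "on"})

RBLN_SAMPLER: bool = (
    os.environ.get("VLLM_RBLN_SAMPLER", "").strip().lower() in _TRUTHY_VALUES
)

When the flag is set, the runner wraps the RBLN sampler in torch.compile with dynamic=False and fullgraph=False, so each distinct batch shape is compiled and cached separately instead of being traced with dynamic shapes.
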
|
@@ -622,3 +606,113 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: |
622 | 606 |
|
623 | 607 | def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: |
624 | 608 | pass |
| 609 | + |
| 610 | + def dummy_sampler_run(self): |
| 611 | + if not self.use_rbln_sampler: |
| 612 | + return |
| 613 | + |
| 614 | + def set_sampling_tensors(input_batch, **params): |
| 615 | + input_batch.temperature_cpu_tensor.fill_(params["temperature"]) |
| 616 | + input_batch.temperature.fill_(params["temperature"]) |
| 617 | + |
| 618 | + optional_keys = [ |
| 619 | + ("top_p", input_batch.top_p_cpu_tensor, input_batch.top_p), |
| 620 | + ("top_k", input_batch.top_k_cpu_tensor, input_batch.top_k), |
| 621 | + ("min_p", input_batch.min_p_cpu_tensor, input_batch.min_p), |
| 622 | + ("frequency_penalties", |
| 623 | + input_batch.frequency_penalties_cpu_tensor, |
| 624 | + input_batch.frequency_penalties), |
| 625 | + ("presence_penalties", |
| 626 | + input_batch.presence_penalties_cpu_tensor, |
| 627 | + input_batch.presence_penalties), |
| 628 | + ("repetition_penalties", |
| 629 | + input_batch.repetition_penalties_cpu_tensor, |
| 630 | + input_batch.repetition_penalties), |
| 631 | + ] |
| 632 | + |
| 633 | + for key, cpu_tensor, dev_tensor in optional_keys: |
| 634 | + val = params.get(key) |
| 635 | + if val is not None: |
| 636 | + cpu_tensor.fill_(val) |
| 637 | + dev_tensor.fill_(val) |
| 638 | + |
| 639 | + def populate_reqs(input_batch, base_config, batch_size): |
| 640 | + for i in range(batch_size): |
| 641 | + req_id = f"{base_config['name']}_req_{i}" |
| 642 | + input_batch._req_ids.append(req_id) |
| 643 | + input_batch.req_id_to_index[req_id] = i |
| 644 | + |
| 645 | + if base_config["all_greedy"]: |
| 646 | + input_batch.greedy_reqs.add(req_id) |
| 647 | + elif base_config["all_random"]: |
| 648 | + input_batch.random_reqs.add(req_id) |
| 649 | + |
| 650 | + for attr, req_set in [ |
| 651 | + ("top_p", input_batch.top_p_reqs), |
| 652 | + ("top_k", input_batch.top_k_reqs), |
| 653 | + ("frequency_penalties", |
| 654 | + input_batch.frequency_penalties_reqs), |
| 655 | + ("repetition_penalties", |
| 656 | + input_batch.repetition_penalties_reqs), |
| 657 | + ("presence_penalties", |
| 658 | + input_batch.presence_penalties_reqs), |
| 659 | + ]: |
| 660 | + if base_config.get(attr) is not None: |
| 661 | + req_set.add(req_id) |
| 662 | + |
| 663 | + def clear_reqs(input_batch): |
| 664 | + input_batch._req_ids.clear() |
| 665 | + input_batch.req_id_to_index.clear() |
| 666 | + input_batch.greedy_reqs.clear() |
| 667 | + input_batch.random_reqs.clear() |
| 668 | + input_batch.top_p_reqs.clear() |
| 669 | + input_batch.top_k_reqs.clear() |
| 670 | + input_batch.frequency_penalties_reqs.clear() |
| 671 | + input_batch.repetition_penalties_reqs.clear() |
| 672 | + input_batch.presence_penalties_reqs.clear() |
| 673 | + |
| 674 | + def dummy_run_batches(base_config): |
| 675 | + for batch_size in range(1, self.input_batch.max_num_reqs + 1): |
| 676 | + input_batch = self.input_batch |
| 677 | + populate_reqs(input_batch, base_config, batch_size) |
| 678 | + |
| 679 | + metadata = input_batch._make_sampling_metadata() |
| 680 | + metadata.no_penalties = base_config["no_penalties"] |
| 681 | + metadata.all_greedy = base_config["all_greedy"] |
| 682 | + metadata.all_random = base_config["all_random"] |
| 683 | + |
| 684 | + if (not metadata.no_penalties |
| 685 | + and metadata.prompt_token_ids is None): |
| 686 | + metadata.prompt_token_ids = torch.zeros((batch_size, 1), |
| 687 | + dtype=torch.long, |
| 688 | + device="cpu") |
| 689 | + |
| 690 | + logger.info( |
| 691 | + "Running dummy compile with batch_size=%d, vocab_size=%d", |
| 692 | + batch_size, input_batch.vocab_size) |
| 693 | + logger.info("Sampling metadata: %s", metadata) |
| 694 | + |
| 695 | + with torch.inference_mode(): |
| 696 | + empty_logits = torch.empty(batch_size, |
| 697 | + input_batch.vocab_size, |
| 698 | + dtype=torch.float32) |
| 699 | + _ = self.sampler(logits=empty_logits, |
| 700 | + sampling_metadata=metadata) |
| 701 | + |
| 702 | + clear_reqs(input_batch) |
| 703 | + |
| 704 | + for config in WARM_UP_CONFIGS: |
| 705 | + logger.info("Running dummy sampler config: %s", config["name"]) |
| 706 | + |
| 707 | + set_sampling_tensors( |
| 708 | + self.input_batch, |
| 709 | + temperature=config["temperature"], |
| 710 | + top_p=config.get("top_p"), |
| 711 | + top_k=config.get("top_k"), |
| 712 | + min_p=config.get("min_p"), |
| 713 | + frequency_penalties=config.get("frequency_penalties"), |
| 714 | + repetition_penalties=config.get("repetition_penalties"), |
| 715 | + presence_penalties=config.get("presence_penalties"), |
| 716 | + ) |
| 717 | + |
| 718 | + dummy_run_batches(config) |
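
The warm-up loop drives the compiled sampler through every entry in WARM_UP_CONFIGS at every batch size from 1 to max_num_reqs, so torch.compile captures each shape/config combination ahead of serving. The actual configs live in vllm_rbln.v1.sample.sampler and are not part of this diff; the entry below is a hypothetical illustration built only from the keys dummy_sampler_run reads:

# Hypothetical WARM_UP_CONFIGS entry (illustrative only; the real list is
# defined in vllm_rbln.v1.sample.sampler). Keys mirror the lookups in
# dummy_sampler_run above.
EXAMPLE_WARM_UP_CONFIG = {
    "name": "random_top_p_top_k",  # used to build dummy request ids
    "all_greedy": False,           # forced onto SamplingMetadata.all_greedy
    "all_random": True,            # forced onto SamplingMetadata.all_random
    "no_penalties": True,          # skips the penalty path
    "temperature": 0.7,            # required; filled into CPU and device tensors
    "top_p": 0.9,                  # optional keys may be omitted (treated as None)
    "top_k": 20,
}

Because the sampler is compiled with dynamic=False, each (batch_size, config) pair triggers its own compilation during this warm-up rather than on the first live request that hits that shape.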