|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import os |
15 | 16 | from typing import TYPE_CHECKING |
16 | 17 |
|
17 | 18 | import torch |
@@ -112,17 +113,39 @@ def pre_register_and_update( |
112 | 113 | if action.dest == "block_size": |
113 | 114 | action.choices = None # Override choices |
114 | 115 |
|
@classmethod
def validate_and_setup_prerequisite(cls, vllm_config: VllmConfig) -> None:
    """Check RBLN prerequisites and prepare the process environment.

    Args:
        vllm_config: Configuration whose ``scheduler_config`` and
            ``parallel_config`` are inspected.

    Raises:
        ValueError: If chunked prefill is disabled in the scheduler config.
        RuntimeError: If ``envs.VLLM_RBLN_PROFILER`` is truthy while any
            form of model parallelism is requested.

    Side effects:
        When tensor, pipeline, or data parallel size exceeds 1, or expert
        parallelism is enabled, ``RBLN_CTX_STANDALONE`` and
        ``RBLN_FORCE_CCL_ASYNC`` are set to ``"1"`` in ``os.environ``.
    """
    if not vllm_config.scheduler_config.enable_chunked_prefill:
        raise ValueError(
            "RBLN does not officially support disabling chunked prefill. "
            "Please don't disable chunked prefill by yourself."
        )

    parallel = vllm_config.parallel_config
    # Any one of these settings counts as running model parallel.
    wants_model_parallel = (
        parallel.tensor_parallel_size > 1
        or parallel.pipeline_parallel_size > 1
        or parallel.data_parallel_size > 1
        or parallel.enable_expert_parallel
    )
    if not wants_model_parallel:
        return

    # The profiler check runs before any environment mutation, so a
    # rejected configuration leaves os.environ untouched.
    if envs.VLLM_RBLN_PROFILER:
        raise RuntimeError(
            "RBLN_PROFILER is not supported when using vLLM model parallel "
            "(TP, DP, EP, or PP)."
        )

    for flag_name in ("RBLN_CTX_STANDALONE", "RBLN_FORCE_CCL_ASYNC"):
        os.environ[flag_name] = "1"
| 140 | + |
115 | 141 | @classmethod |
116 | 142 | def check_and_update_config(cls, vllm_config: VllmConfig) -> None: |
117 | 143 | model_config = vllm_config.model_config |
118 | 144 | parallel_config = vllm_config.parallel_config |
119 | 145 | scheduler_config = vllm_config.scheduler_config |
120 | 146 |
|
121 | 147 | if envs.VLLM_RBLN_USE_VLLM_MODEL: |
122 | | - assert scheduler_config.enable_chunked_prefill, ( |
123 | | - "RBLN does not officially support disabling chunked prefill. " |
124 | | - "Please don't disable chunked prefill by yourself." |
125 | | - ) |
| 148 | + cls.validate_and_setup_prerequisite(vllm_config) |
126 | 149 | if envs.VLLM_RBLN_ENFORCE_MODEL_FP32: |
127 | 150 | logger.info("original model_config.dtype = %s", model_config.dtype) |
128 | 151 | if model_config.dtype == torch.bfloat16: |
|
0 commit comments