Skip to content

Commit c2dc756

Browse files
core(platform): check prerequisite for parallelism (#393)
* core(platform): check prerequisites for parallelism
* Set CCL environment variables
1 parent 31ed5d0 commit c2dc756

2 files changed

Lines changed: 34 additions & 9 deletions

File tree

vllm_rbln/platform.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
from typing import TYPE_CHECKING
1617

1718
import torch
@@ -112,17 +113,39 @@ def pre_register_and_update(
112113
if action.dest == "block_size":
113114
action.choices = None # Override choices
114115

116+
@classmethod
117+
def validate_and_setup_prerequisite(cls, vllm_config: VllmConfig) -> None:
118+
scheduler_config = vllm_config.scheduler_config
119+
if not scheduler_config.enable_chunked_prefill:
120+
raise ValueError(
121+
"RBLN does not officially support disabling chunked prefill. "
122+
"Please don't disable chunked prefill by yourself."
123+
)
124+
125+
parallel_config = vllm_config.parallel_config
126+
use_model_parallel = (
127+
parallel_config.tensor_parallel_size > 1
128+
or parallel_config.pipeline_parallel_size > 1
129+
or parallel_config.data_parallel_size > 1
130+
or parallel_config.enable_expert_parallel
131+
)
132+
if use_model_parallel:
133+
if envs.VLLM_RBLN_PROFILER:
134+
raise RuntimeError(
135+
"RBLN_PROFILER is not supported when using vLLM model parallel "
136+
"(TP, DP, EP, or PP)."
137+
)
138+
os.environ["RBLN_CTX_STANDALONE"] = "1"
139+
os.environ["RBLN_FORCE_CCL_ASYNC"] = "1"
140+
115141
@classmethod
116142
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
117143
model_config = vllm_config.model_config
118144
parallel_config = vllm_config.parallel_config
119145
scheduler_config = vllm_config.scheduler_config
120146

121147
if envs.VLLM_RBLN_USE_VLLM_MODEL:
122-
assert scheduler_config.enable_chunked_prefill, (
123-
"RBLN does not officially support disabling chunked prefill. "
124-
"Please don't disable chunked prefill by yourself."
125-
)
148+
cls.validate_and_setup_prerequisite(vllm_config)
126149
if envs.VLLM_RBLN_ENFORCE_MODEL_FP32:
127150
logger.info("original model_config.dtype = %s", model_config.dtype)
128151
if model_config.dtype == torch.bfloat16:

vllm_rbln/rbln_envs.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,6 @@ def get_decode_batch_bucket_manual_buckets() -> list[int]:
226226
"VLLM_RBLN_NUMA": (
227227
lambda: os.environ.get("VLLM_RBLN_NUMA", "True").lower() in ("true", "1")
228228
),
229-
"VLLM_RBLN_USE_CUSTOM_KERNEL": (
230-
lambda: (
231-
os.environ.get("RBLN_USE_CUSTOM_KERNEL", "False").lower() in ("true", "1")
232-
)
233-
),
234229
"VLLM_RBLN_SORT_BATCH": (
235230
lambda: os.environ.get("VLLM_RBLN_SORT_BATCH", "False").lower() in ("true", "1")
236231
),
@@ -254,6 +249,13 @@ def get_decode_batch_bucket_manual_buckets() -> list[int]:
254249
),
255250
# Decode batch bucket manual buckets
256251
"VLLM_RBLN_DECODE_BATCH_BUCKET_MANUAL_BUCKETS": get_decode_batch_bucket_manual_buckets, # noqa E501
252+
"VLLM_RBLN_USE_CUSTOM_KERNEL": (
253+
lambda: os.environ.get("RBLN_USE_CUSTOM_KERNEL", "False").lower()
254+
in ("true", "1")
255+
),
256+
"VLLM_RBLN_PROFILER": (
257+
lambda: os.environ.get("RBLN_PROFILER", "False").lower() in ("true", "1")
258+
),
257259
}
258260

259261

0 commit comments

Comments
 (0)