diff --git a/vllm_omni/entrypoints/omni.py b/vllm_omni/entrypoints/omni.py index 97357dc3b3..eec6537a2a 100644 --- a/vllm_omni/entrypoints/omni.py +++ b/vllm_omni/entrypoints/omni.py @@ -42,6 +42,9 @@ resolve_model_config_path, ) from vllm_omni.inputs.data import OmniDiffusionSamplingParams, OmniPromptType, OmniSamplingParams +from vllm_omni.model_executor.model_loader.weight_utils import ( + download_weights_from_hf_specific, +) from vllm_omni.outputs import OmniRequestOutput logger = init_logger(__name__) @@ -68,14 +71,24 @@ def _dummy_snapshot_download(model_id): def omni_snapshot_download(model_id) -> str: + # If it's already a local path, just return it + if os.path.exists(model_id): + return model_id # TODO: this is just a workaround for quickly use modelscope, we should support # modelscope in weight loading feature instead of using `snapshot_download` if os.environ.get("VLLM_USE_MODELSCOPE", False): from modelscope.hub.snapshot_download import snapshot_download return snapshot_download(model_id) - else: - return _dummy_snapshot_download(model_id) + # For other cases (Hugging Face), perform a real download to ensure all + # necessary files (including *.pt for audio/diffusion) are available locally + # before stage workers are spawned. This prevents initialization timeouts. + return download_weights_from_hf_specific( + model_name_or_path=model_id, + cache_dir=None, + allow_patterns=["**/*.json", "**/*.bin", "**/*.safetensors", "**/*.pt", "**/*.txt", "**/*.model", "**/*.yaml"], + require_all=True, + ) class OmniBase: diff --git a/vllm_omni/model_executor/model_loader/weight_utils.py b/vllm_omni/model_executor/model_loader/weight_utils.py index 7432ad9a2a..d147269d66 100644 --- a/vllm_omni/model_executor/model_loader/weight_utils.py +++ b/vllm_omni/model_executor/model_loader/weight_utils.py @@ -20,6 +20,7 @@ def download_weights_from_hf_specific( allow_patterns: list[str], revision: str | None = None, ignore_patterns: str | list[str] | None = None, + require_all: bool = False, ) -> str: """Download model weights from Hugging Face Hub. Users can specify the allow_patterns to download only the necessary weights. @@ -35,6 +36,9 @@ def download_weights_from_hf_specific( ignore_patterns (Optional[Union[str, list[str]]]): The patterns to filter out the weight files. Files matched by any of the patterns will be ignored. + require_all (bool): If True, will iterate through and download files + matching all patterns in allow_patterns. If False, will stop after + the first pattern that matches any files. Returns: str: The path to the downloaded model weights. @@ -59,8 +63,8 @@ def download_weights_from_hf_specific( **download_kwargs, ) # If we have downloaded weights for this allow_pattern, - # we don't need to check the rest. - if any(Path(hf_folder).glob(allow_pattern)): + # we don't need to check the rest, unless require_all is set. + if not require_all and any(Path(hf_folder).glob(allow_pattern)): break time_taken = time.perf_counter() - start_time if time_taken > 0.5: