Skip to content

Commit e319150

Browse files
authored
[Bugfix] Fix for builtins (forward fix of pytorch/177558) (vllm-project#37234)
Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
1 parent 29e4870 commit e319150

2 files changed

Lines changed: 44 additions & 1 deletion

File tree

tools/pre_commit/check_forbidden_imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class ForbiddenImport:
3131
"vllm/transformers_utils/config.py",
3232
"vllm/model_executor/models/registry.py",
3333
"vllm/compilation/caching.py",
34+
"vllm/env_override.py",
3435
"vllm/compilation/piecewise_backend.py",
3536
"vllm/distributed/utils.py",
3637
"vllm/distributed/parallel_state.py",

vllm/env_override.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def _maybe_set_cuda_compatibility_path():
8787
import torch
8888

8989
from vllm.logger import init_logger
90-
from vllm.utils.torch_utils import is_torch_equal
90+
from vllm.utils.torch_utils import is_torch_equal, is_torch_equal_or_newer
9191

9292
logger = init_logger(__name__)
9393

@@ -490,3 +490,45 @@ def _patch_get_raw_stream_if_needed():
490490

491491
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
492492
GraphLowering._update_scheduler = _update_scheduler_patched
493+
494+
# ===================================================
495+
# torch <2.12 GraphCaptureOutput.get_runtime_env monkeypatch
496+
# ===================================================
497+
# PyTorch's AOT compile path omits builtins from used_globals, causing
498+
# 'Missing required external references' errors for refs like 'type'.
499+
# (which happens in transformers code)
500+
# This mirrors the fix in https://github.com/pytorch/pytorch/pull/177558
501+
# and can be removed once torch >=2.12 is the minimum supported version.
502+
503+
if not is_torch_equal_or_newer("2.12.0"):
504+
import builtins as _builtins
505+
import pickle
506+
507+
from torch._dynamo.convert_frame import GraphCaptureOutput
508+
509+
_original_get_runtime_env = GraphCaptureOutput.get_runtime_env
510+
511+
def _safe_builtins_dict(builtins_dict: dict) -> dict:
512+
"""Filter a builtins dict to only picklable entries for serialization."""
513+
result = {}
514+
for k, v in builtins_dict.items():
515+
try:
516+
pickle.dumps(v)
517+
result[k] = v
518+
except Exception:
519+
pass
520+
return result
521+
522+
def _patched_get_runtime_env(self): # type: ignore[no-untyped-def]
523+
runtime_env = _original_get_runtime_env(self)
524+
for ref in runtime_env.external_refs:
525+
if ref not in runtime_env.used_globals:
526+
if ref.startswith("__builtins_dict__") and ref in self.f_globals:
527+
runtime_env.used_globals[ref] = _safe_builtins_dict(
528+
self.f_globals[ref]
529+
)
530+
elif hasattr(_builtins, ref):
531+
runtime_env.used_globals[ref] = getattr(_builtins, ref)
532+
return runtime_env
533+
534+
GraphCaptureOutput.get_runtime_env = _patched_get_runtime_env

0 commit comments

Comments
 (0)