8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
@@ -136,6 +136,14 @@ repos:
types: [python]
pass_filenames: false
additional_dependencies: [regex]
# forbid use of direct torch.cuda APIs
- id: forbid-direct-torch-cuda
name: "Forbid direct 'torch.cuda' APIs"
entry: python tools/pre_commit/check_torch_cuda.py
language: python
types: [python]
pass_filenames: false
Member commented:
We should pass file names; otherwise this will run on all files on every commit.

Suggested change
pass_filenames: false
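If this suggestion is taken (dropping `pass_filenames: false` so that pre-commit passes the changed Python files to the hook), the checker would read paths from its arguments instead of diffing the repository itself. A rough, hypothetical sketch of that variant (not part of this diff; it scans whole files rather than only added lines, so the allowed directories would have to move into the hook's `exclude:` pattern):

import sys

import regex as re

_TORCH_CUDA_RE = re.compile(r"\btorch\.cuda\.empty_cache\b")


def main(paths: list[str]) -> int:
    violations = []
    for path in paths:  # pre-commit passes only the staged files matching `types`
        with open(path, encoding="utf-8") as f:
            for lineno, line in enumerate(f, start=1):
                if _TORCH_CUDA_RE.search(line):
                    violations.append(f"{path}:{lineno}: {line.strip()}")
    for violation in violations:
        print(violation)
    return 1 if violations else 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))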

additional_dependencies: [regex]
- id: check-pickle-imports
name: Prevent new pickle/cloudpickle imports
entry: python tools/pre_commit/check_pickle_imports.py
@@ -120,7 +120,7 @@ def main():
# Clean up the GPU memory for the next test
del engine
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/offline_inference/rlhf_colocate.py
@@ -154,7 +154,7 @@ def get_size(p: torch.Tensor) -> int:
s.close()
del buffer
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()


# Ray manages four GPUs.
2 changes: 1 addition & 1 deletion examples/offline_inference/rlhf_utils.py
@@ -150,7 +150,7 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
socket.close()
del buffer
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()

def report_device_id(self) -> str:
from vllm.platforms import current_platform
85 changes: 85 additions & 0 deletions tools/pre_commit/check_torch_cuda.py
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import subprocess
import sys

import regex as re

# --------------------------------------------------------------------------- #
# Regex: match `torch.cuda.empty_cache` but allow `torch.accelerator.empty_cache`
# --------------------------------------------------------------------------- #
_TORCH_CUDA_RE = re.compile(r"\btorch\.cuda\.empty_cache\b")


ALLOWED_FILES = {"tests/", "benchmarks/", "vllm/platforms/*"}


def is_allowed_file(current_file: str) -> bool:
    return current_file in ALLOWED_FILES
Contributor commented (critical):

The check `current_file in ALLOWED_FILES` performs an exact match, but `ALLOWED_FILES` contains directory prefixes. For example, a file `tests/models/test_llama.py` will not be matched against `tests/`. This will cause the pre-commit hook to incorrectly flag files that should be allowed. You should use `startswith` to check if the file path is under one of the allowed directories. Also, `vllm/platforms/*` seems to intend to match all files in the directory, so it should probably be `vllm/platforms/`.

Suggested change
-ALLOWED_FILES = {"tests/", "benchmarks/", "vllm/platforms/*"}
-def is_allowed_file(current_file: str) -> bool:
-    return current_file in ALLOWED_FILES
+ALLOWED_FILES = {"tests/", "benchmarks/", "vllm/platforms/"}
+def is_allowed_file(current_file: str) -> bool:
+    return any(current_file.startswith(p) for p in ALLOWED_FILES)
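To make the difference concrete (illustrative only, not part of the suggestion):

path = "tests/models/test_llama.py"
path in {"tests/", "benchmarks/", "vllm/platforms/"}  # False: exact membership test
any(path.startswith(prefix) for prefix in {"tests/", "benchmarks/", "vllm/platforms/"})  # True: prefix match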



def is_forbidden_torch_cuda_api(line: str) -> bool:
    stripped = line.strip()
    return bool(_TORCH_CUDA_RE.match(stripped))
Contributor commented (critical):

`re.match` only checks for a match at the beginning of the string. This will fail to detect forbidden API calls that are not at the start of a line (after stripping whitespace), for example `x = torch.cuda.empty_cache()`. You should use `re.search` to find a match anywhere in the line.

Suggested change
-def is_forbidden_torch_cuda_api(line: str) -> bool:
-    stripped = line.strip()
-    return bool(_TORCH_CUDA_RE.match(stripped))
+def is_forbidden_torch_cuda_api(line: str) -> bool:
+    return bool(_TORCH_CUDA_RE.search(line))
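Concretely (illustrative only), with `_TORCH_CUDA_RE` as defined above:

line = "x = torch.cuda.empty_cache()"
bool(_TORCH_CUDA_RE.match(line.strip()))  # False: match() is anchored to the start of the string
bool(_TORCH_CUDA_RE.search(line))  # True: search() scans the whole line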



def parse_diff(diff: str) -> list[str]:
    violations = []
    current_file = None
    current_lineno = None
    skip_allowed_file = False

    for line in diff.splitlines():
        if line.startswith("+++ b/"):
            current_file = line[6:]
            skip_allowed_file = is_allowed_file(current_file)
        elif skip_allowed_file:
            continue
        elif line.startswith("@@"):
            match = re.search(r"\+(\d+)", line)
            if match:
                current_lineno = int(match.group(1)) - 1  # next "+ line" is here
        elif line.startswith("+") and not line.startswith("++"):
            current_lineno += 1
            code_line = line[1:]
            if is_forbidden_torch_cuda_api(code_line):
                violations.append(
                    f"{current_file}:{current_lineno}: {code_line.strip()}"
                )
    return violations
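
# Illustration (not part of the file): how parse_diff walks a --unified=0 diff.
# The "+++ b/" line sets current_file, the "@@" header seeds the line counter at
# the hunk's new start minus one, and each added "+" line advances it before the
# regex check, e.g.:
#
#     sample_diff = (
#         "+++ b/vllm/foo.py\n"
#         "@@ -10,0 +11,2 @@ def bar():\n"
#         "+    gc.collect()\n"
#         "+    torch.cuda.empty_cache()\n"
#     )
#     parse_diff(sample_diff) == ["vllm/foo.py:12: torch.cuda.empty_cache()"]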


def get_diff(diff_type: str) -> str:
    if diff_type == "staged":
        return subprocess.check_output(
            ["git", "diff", "--cached", "--unified=0"], text=True
        )
    elif diff_type == "unstaged":
        return subprocess.check_output(["git", "diff", "--unified=0"], text=True)
    else:
        raise ValueError(f"Unknown diff_type: {diff_type}")


def main():
    all_violations = []
    for diff_type in ["staged", "unstaged"]:
        try:
            diff_output = get_diff(diff_type)
            violations = parse_diff(diff_output)
            all_violations.extend(violations)
        except subprocess.CalledProcessError as e:
            print(f"[{diff_type}] Git diff failed: {e}", file=sys.stderr)

    if all_violations:
        print(
            "❌ Forbidden direct `torch.cuda.empty_cache` detected."
            " ➤ Use `torch.accelerator.empty_cache` instead.\n"
        )
        for v in all_violations:
            print(f"❌ {v}")
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())
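Usage note (not part of the diff): because the hook sets `pass_filenames: false` and the script diffs the working tree itself, it can also be run by hand from the repository root with `python tools/pre_commit/check_torch_cuda.py`, or through `pre-commit run forbid-direct-torch-cuda`; it prints the offending `file:line` pairs and exits 1 when a staged or unstaged change adds `torch.cuda.empty_cache` outside the allowed directories.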
4 changes: 3 additions & 1 deletion vllm/compilation/cuda_graph.py
@@ -256,7 +256,9 @@ def __call__(self, *args, **kwargs):
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
stack.enter_context(
patch("torch.accelerator.empty_cache", lambda: None)
)

if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
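The change above swaps the patched target from `torch.cuda.empty_cache` to `torch.accelerator.empty_cache` inside the `stack.enter_context(patch(...))` calls. A minimal self-contained sketch of that ExitStack plus mock.patch pattern, assuming a PyTorch build that provides `torch.accelerator.empty_cache` (illustrative, not vLLM's actual code path):

from contextlib import ExitStack
from unittest.mock import patch

with ExitStack() as stack:
    # Turn gc.collect and the accelerator cache flush into no-ops for the
    # duration of the block; ExitStack restores both patches on exit.
    stack.enter_context(patch("gc.collect", lambda: None))
    stack.enter_context(patch("torch.accelerator.empty_cache", lambda: None))
    ...  # capture the remaining graphs here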
@@ -199,7 +199,7 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
):
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
torch.accelerator.empty_cache()

return weight

@@ -557,7 +557,7 @@ def process_weights_after_loading(self, layer):
requires_grad=layer.w2_weight.requires_grad,
)

torch.cuda.empty_cache()
torch.accelerator.empty_cache()

def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -1237,7 +1237,7 @@ def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
import torch.nn.functional as F

weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
return weight


2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -816,7 +816,7 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
**stacked_quant_state_dict,
}
self._bind_quant_states_to_params(model, stacked_quant_state_dict)
torch.cuda.empty_cache()
torch.accelerator.empty_cache()

def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
6 changes: 3 additions & 3 deletions vllm/utils/mem_utils.py
@@ -79,7 +79,7 @@ def measure(self) -> None:
# rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
# when we call `torch.accelerator.empty_cache()` or OOM happens.
self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)

self.free_memory, self.total_memory = torch.cuda.mem_get_info()
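The comment above explains why the peak from `torch.cuda.memory_stats()` is used rather than `torch.cuda.memory_reserved()`. A minimal sketch of that measurement pattern, assuming a CUDA device (not vLLM's exact code):

import torch

torch.accelerator.empty_cache()  # release cached blocks so reserved memory can shrink
torch.cuda.reset_peak_memory_stats()
# ... run the workload being measured ...
peak_allocated = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
free_bytes, total_bytes = torch.cuda.mem_get_info()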
@@ -201,7 +201,7 @@ def memory_profiling(
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
""" # noqa
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
torch.cuda.reset_peak_memory_stats()

result = MemoryProfilingResult()
@@ -215,7 +215,7 @@ def memory_profiling(
yield result

gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()

result.after_profile.measure()

2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu/model_runner.py
@@ -341,7 +341,7 @@ def capture_model(self) -> int:

start_time = time.perf_counter()
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]

with self.maybe_setup_dummy_loras(self.lora_config):
6 changes: 3 additions & 3 deletions vllm/v1/worker/gpu_worker.py
@@ -234,7 +234,7 @@ def init_device(self):

# Now take memory snapshot after NCCL is initialized
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()

# take current memory snapshot
self.init_snapshot = MemorySnapshot()
@@ -328,7 +328,7 @@ def determine_available_memory(self) -> int:
logger.info(msg)
return kv_cache_memory_bytes

torch.cuda.empty_cache()
torch.accelerator.empty_cache()
torch.cuda.reset_peak_memory_stats()

# Execute a forward pass with dummy inputs to profile the memory usage
@@ -519,7 +519,7 @@ def compile_or_warm_up_model(self) -> None:
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
# memory buffers from being cleared by `torch.accelerator.empty_cache`.
if get_pp_group().is_last_rank:
max_num_reqs = min(
self.scheduler_config.max_num_seqs,