7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
@@ -136,6 +136,13 @@ repos:
types: [python]
pass_filenames: false
additional_dependencies: [regex]
# Prevent use of torch.cuda APIs
- id: check-torch-cuda-call
name: "Prevent new 'torch.cuda' API calls"
entry: python tools/pre_commit/check_torch_cuda.py
language: python
types: [python]
additional_dependencies: [regex]
- id: check-pickle-imports
name: Prevent new pickle/cloudpickle imports
entry: python tools/pre_commit/check_pickle_imports.py
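Usage note (not part of the diff): once this hook is in place, it can be exercised against the whole tree with `pre-commit run check-torch-cuda-call --all-files`; during a normal commit, pre-commit runs it only on the staged Python files.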
@@ -120,7 +120,7 @@ def main():
# Clean up the GPU memory for the next test
del engine
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion examples/offline_inference/rlhf_colocate.py
@@ -154,7 +154,7 @@ def get_size(p: torch.Tensor) -> int:
s.close()
del buffer
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()


# Ray manages four GPUs.
2 changes: 1 addition & 1 deletion examples/offline_inference/rlhf_utils.py
@@ -150,7 +150,7 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
socket.close()
del buffer
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()

def report_device_id(self) -> str:
from vllm.platforms import current_platform
39 changes: 39 additions & 0 deletions tools/pre_commit/check_torch_cuda.py
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys

import regex as re

# --------------------------------------------------------------------------- #
# Regex: flag `torch.cuda.empty_cache`; use `torch.accelerator.empty_cache` instead
# --------------------------------------------------------------------------- #
_TORCH_CUDA_RE = re.compile(r"\btorch\.cuda\.empty_cache\b")


ALLOWED_FILES = {"tests/", "benchmarks/", "vllm/platforms/"}


def scan_file(path: str) -> int:
with open(path, encoding="utf-8") as f:
for i, line in enumerate(f, 1):
if _TORCH_CUDA_RE.search(line):
print(
f"{path}:{i}: "
"\033[91merror:\033[0m " # red color
"Found torch.cuda API call"
)
return 1
return 0


def main():
returncode = 0
for filename in sys.argv[1:]:
if any(filename.startswith(prefix) for prefix in ALLOWED_FILES):
Reviewer comment (Member): Could be made slightly faster by combining ALLOWED_FILES into a regex pattern and using re.match instead of a Python loop; see the sketch after this file.
continue
returncode |= scan_file(filename)
return returncode


if __name__ == "__main__":
sys.exit(main())
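
A minimal sketch of the reviewer's suggestion above, using the same `regex` dependency; `ALLOWED_FILES_RE` and `is_allowed` are illustrative names, not part of this PR:

import regex as re

# Collapse the allow-listed path prefixes into one anchored alternation so a
# single match() call replaces the per-prefix startswith() loop.
ALLOWED_FILES_RE = re.compile(r"^(?:tests/|benchmarks/|vllm/platforms/)")


def is_allowed(filename: str) -> bool:
    # match() is anchored at the start of the string, so this mirrors
    # any(filename.startswith(prefix) for prefix in ALLOWED_FILES).
    return ALLOWED_FILES_RE.match(filename) is not None

main() would then test is_allowed(filename) instead of iterating over ALLOWED_FILES.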
4 changes: 3 additions & 1 deletion vllm/compilation/cuda_graph.py
@@ -256,7 +256,9 @@ def __call__(self, *args, **kwargs):
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(patch("torch.cuda.empty_cache", lambda: None))
stack.enter_context(
patch("torch.accelerator.empty_cache", lambda: None)
)

if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
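For context, a rough sketch of the capture-time patching pattern in the cuda_graph.py hunk above, assuming `unittest.mock.patch` and a `contextlib.ExitStack` as in the surrounding code; the capture call itself is elided:

import contextlib
from unittest.mock import patch


def capture_without_cleanup():
    with contextlib.ExitStack() as stack:
        # Per the comments in the hunk, only the first graph runs gc; for
        # later graphs both calls are patched to no-ops for the duration of
        # the capture.
        stack.enter_context(patch("gc.collect", lambda: None))
        stack.enter_context(patch("torch.accelerator.empty_cache", lambda: None))
        ...  # graph capture would run here while the patches are active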
@@ -199,7 +199,7 @@ def _maybe_pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
):
num_pad = 256 // weight.element_size()
weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
torch.accelerator.empty_cache()

return weight

@@ -557,7 +557,7 @@ def process_weights_after_loading(self, layer):
requires_grad=layer.w2_weight.requires_grad,
)

torch.cuda.empty_cache()
torch.accelerator.empty_cache()

def get_fused_moe_quant_config(
self, layer: torch.nn.Module
@@ -1237,7 +1237,7 @@ def _maybe_pad_fp8_weight(weight: torch.Tensor) -> torch.Tensor:
import torch.nn.functional as F

weight = F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad]
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
return weight


2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/bitsandbytes_loader.py
@@ -816,7 +816,7 @@ def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
**stacked_quant_state_dict,
}
self._bind_quant_states_to_params(model, stacked_quant_state_dict)
torch.cuda.empty_cache()
torch.accelerator.empty_cache()

def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
6 changes: 3 additions & 3 deletions vllm/utils/mem_utils.py
@@ -79,7 +79,7 @@ def measure(self) -> None:
# rather than `torch.cuda.memory_reserved()` .
# After `torch.cuda.reset_peak_memory_stats()`,
# `torch.cuda.memory_reserved()` will keep growing, and only shrink
# when we call `torch.cuda.empty_cache()` or OOM happens.
# when we call `torch.accelerator.empty_cache()` or OOM happens.
self.torch_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)

self.free_memory, self.total_memory = torch.cuda.mem_get_info()
@@ -201,7 +201,7 @@ def memory_profiling(
The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
""" # noqa
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
torch.cuda.reset_peak_memory_stats()

result = MemoryProfilingResult()
@@ -215,7 +215,7 @@
yield result

gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()

result.after_profile.measure()

2 changes: 1 addition & 1 deletion vllm/v1/worker/gpu/model_runner.py
@@ -341,7 +341,7 @@ def capture_model(self) -> int:

start_time = time.perf_counter()
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]

with self.maybe_setup_dummy_loras(self.lora_config):
6 changes: 3 additions & 3 deletions vllm/v1/worker/gpu_worker.py
@@ -234,7 +234,7 @@ def init_device(self):

# Now take memory snapshot after NCCL is initialized
gc.collect()
torch.cuda.empty_cache()
torch.accelerator.empty_cache()

# take current memory snapshot
self.init_snapshot = MemorySnapshot()
@@ -328,7 +328,7 @@ def determine_available_memory(self) -> int:
logger.info(msg)
return kv_cache_memory_bytes

torch.cuda.empty_cache()
torch.accelerator.empty_cache()
torch.cuda.reset_peak_memory_stats()

# Execute a forward pass with dummy inputs to profile the memory usage
Expand Down Expand Up @@ -519,7 +519,7 @@ def compile_or_warm_up_model(self) -> None:
# sampling related tensors of max possible shape to avoid memory
# fragmentation issue.
# NOTE: This is called after `capture_model` on purpose to prevent
# memory buffers from being cleared by `torch.cuda.empty_cache`.
# memory buffers from being cleared by `torch.accelerator.empty_cache`.
if get_pp_group().is_last_rank:
max_num_reqs = min(
self.scheduler_config.max_num_seqs,
2 changes: 1 addition & 1 deletion vllm/v1/worker/xpu_worker.py
@@ -76,7 +76,7 @@ def determine_available_memory(self) -> int:
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.xpu.empty_cache()
torch.accelerator.empty_cache()
torch.xpu.reset_peak_memory_stats()

free_gpu_memory, total_gpu_memory = torch.xpu.mem_get_info()
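
Finally, a minimal sketch of the device-agnostic cleanup idiom this PR converges on for both the CUDA and XPU paths, assuming a PyTorch build that ships the `torch.accelerator` namespace; `release_cached_memory` is an illustrative helper, not a vLLM API:

import gc

import torch


def release_cached_memory() -> None:
    # Drop unreachable Python objects first, then return cached blocks to the
    # device allocator. torch.accelerator dispatches to the active backend
    # (CUDA, XPU, ...), so no per-device module such as torch.cuda or
    # torch.xpu is needed.
    gc.collect()
    if torch.accelerator.is_available():
        torch.accelerator.empty_cache()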