
Commit 4d5f657

integration-vllm-test (#2258)
stack-info: PR: #2258, branch: drisspg/stack/58
1 parent 0b33f12 commit 4d5f657

File tree

1 file changed: +252 -0 lines changed


test/integration/test_vllm.py

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import importlib.util
import os
import random
import shutil
from pathlib import Path
from typing import List

import numpy as np
import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_7

if not TORCH_VERSION_AT_LEAST_2_7:
    pytest.skip("Requires PyTorch 2.7 or higher", allow_module_level=True)


VLLM_AVAILABLE = importlib.util.find_spec("vllm") is not None
TRANSFORMERS_AVAILABLE = importlib.util.find_spec("transformers") is not None

if not VLLM_AVAILABLE:
    pytest.skip("vLLM not installed", allow_module_level=True)

if not TRANSFORMERS_AVAILABLE:
    pytest.skip("transformers not installed", allow_module_level=True)

from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from vllm import LLM, SamplingParams

from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import (
    CutlassInt4PackedLayout,
    Float8DynamicActivationFloat8WeightConfig,
    Int8DynamicActivationInt4WeightConfig,
    Int8WeightOnlyConfig,
)


def get_tests() -> List[TorchAoConfig]:
    """Get all the tests based on device info."""

    # Helper objects for granularity
    per_tensor = PerTensor()
    per_row = PerRow()

    BASE_TESTS = [TorchAoConfig(Int8WeightOnlyConfig())]
    SM89_TESTS = [
        TorchAoConfig(
            Float8DynamicActivationFloat8WeightConfig(granularity=per_tensor)
        ),
        TorchAoConfig(Float8DynamicActivationFloat8WeightConfig(granularity=per_row)),
    ]
    SM90_ONLY_TESTS = [
        TorchAoConfig(
            Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout())
        )
    ]
    SM100_TESTS = [
        # TorchAoConfig(MXFPInferenceConfig())
    ]  # Failing for: https://github.com/pytorch/ao/issues/2239

    # Check CUDA availability first
    if not torch.cuda.is_available():
        return []  # No CUDA, no tests

    major, minor = torch.cuda.get_device_capability()

    # Build test list based on compute capability
    all_tests = []

    # Always include base tests if we have CUDA
    all_tests.extend(BASE_TESTS)

    # Add SM89+ tests
    if major > 8 or (major == 8 and minor >= 9):
        all_tests.extend(SM89_TESTS)

    # Add SM100+ tests
    if major >= 10:
        all_tests.extend(SM100_TESTS)

    # Only works for SM90
    if major == 9:
        all_tests.extend(SM90_ONLY_TESTS)

    return all_tests


class TestVLLMIntegration:
    """Integration tests for vLLM with quantized models."""

    @classmethod
    def setup_class(cls):
        """Set up test environment."""
        # Set seeds for reproducibility
        cls.set_seed(42)

        # See https://github.com/pytorch/ao/issues/2239 for details
        os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
        # For small tests this makes it faster
        os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

    @classmethod
    def teardown_class(cls):
        """Clean up after all tests."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    def setup_method(self, method):
        """Clean up before each test method."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    def teardown_method(self, method):
        """Clean up after each test method."""
        torch.cuda.empty_cache()
        import gc

        gc.collect()

    @staticmethod
    def set_seed(seed):
        """Set random seeds for reproducibility."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    def quantize_and_save_model(
        self,
        model_name: str,
        quantization_config: TorchAoConfig,
        output_dir: Path,
    ):
        """Quantize a model and save it to disk."""
        # Load and quantize model
        quantized_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="bfloat16",
            device_map="cuda",
            quantization_config=quantization_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Save quantized model
        quantized_model.save_pretrained(output_dir, safe_serialization=False)
        tokenizer.save_pretrained(output_dir)

        # Clean up to free memory
        del quantized_model
        torch.cuda.empty_cache()

        return output_dir

    def cleanup_model_directory(self, model_path: Path):
        """Clean up the model directory safely."""
        try:
            if model_path.exists() and model_path.is_dir():
                shutil.rmtree(model_path)
        except (OSError, PermissionError) as e:
            # Log the error but don't fail the test
            print(f"Warning: Failed to clean up {model_path}: {e}")

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
    @pytest.mark.skipif(not VLLM_AVAILABLE, reason="vLLM not installed")
    @pytest.mark.parametrize(
        "quantization_config", get_tests(), ids=lambda config: f"{config.quant_type}"
    )
    @pytest.mark.parametrize("compile", [True, False])
    @pytest.mark.parametrize(
        "tp_size", [1, 2] if torch.cuda.device_count() > 1 else [1]
    )
    def test_vllm_smoke_test(self, tmp_path, quantization_config, compile, tp_size):
        """Test vLLM generation with quantized models."""
        # Skip per_row tests if not supported
        torch._dynamo.reset()

        # Use a small model for testing
        base_model = "facebook/opt-125m"

        # Create a descriptive name for the output directory
        config_name = str(quantization_config).replace("/", "_").replace(" ", "_")[:50]
        output_dir = tmp_path / f"{config_name}-opt-125m"

        llm = None
        quantized_model_path = None

        try:
            # Quantize the model
            quantized_model_path = self.quantize_and_save_model(
                base_model, quantization_config, output_dir
            )

            # Test generation with vLLM
            sampling_params = SamplingParams(
                temperature=0.8,
                top_p=0.95,
                seed=42,
                max_tokens=16,  # Small for testing
            )

            # Create LLM instance
            llm = LLM(
                model=str(quantized_model_path),
                tensor_parallel_size=tp_size,
                enforce_eager=not compile,
                dtype="bfloat16",
                num_gpu_blocks_override=128,
            )

            # Test prompts
            prompts = [
                "Hello, my name is",
                "The capital of France is",
            ]

            # Generate outputs
            outputs = llm.generate(prompts, sampling_params)

            # Verify outputs
            assert len(outputs) == len(prompts)
            for output in outputs:
                assert output.prompt in prompts
                assert len(output.outputs) > 0
                generated_text = output.outputs[0].text
                assert isinstance(generated_text, str)
                assert len(generated_text) > 0

        finally:
            # Clean up resources
            if llm is not None:
                del llm

            # Clean up CUDA memory
            torch.cuda.empty_cache()

            # Clean up the saved model directory
            if quantized_model_path is not None:
                self.cleanup_model_directory(quantized_model_path)


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
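
As a usage sketch: since the smoke test is parametrized over configs, compile modes, and tensor-parallel sizes, a single subset can be run locally with pytest's -k filter. The filter expression below is an assumption for illustration; the actual test IDs are derived from each config's quant_type.

import pytest

# Run only the Int8WeightOnlyConfig parametrizations of the vLLM smoke test.
# The -k expression is illustrative; it matches substrings of the generated test IDs.
pytest.main(["test/integration/test_vllm.py", "-v", "-k", "Int8WeightOnly"])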
