This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 0e13607

Authored by Zhenzhong1, VincyZhang, and changwangss

[vLLM] Support vLLM CPU backend and provide QBits acceleration (#1551)

Co-authored-by: VincyZhang <[email protected]>
Co-authored-by: Wang, Chang <[email protected]>

1 parent 93b12e9 commit 0e13607

File tree

6 files changed (+292, -57 lines)

examples/vllm/README.md

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
# vLLM Acceleration with ITREX

Intel Extension for Transformers (ITREX) integrates the vLLM CPU backend and offers the optional [QBits Module](../../docs/qbits.md) to accelerate vLLM inference on CPUs.

## Installation Methods

1. vLLM installation with CPU: install vLLM from source following the instructions provided [here](https://docs.vllm.ai/en/latest/getting_started/cpu-installation.html).

2. ITREX installation: install ITREX following the [getting started guide](../../docs/get_started.md).

3. Dependencies: install the additional dependencies listed in `requirement.txt` in this directory.

Note: torch==2.3.0+cpu is required and vllm==0.4.2+cpu is validated (a quick version check is sketched below).
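A quick way to confirm the versions called out in the note above is a couple of lines of Python. This is a minimal sanity-check sketch, not part of the ITREX tooling:

```python
# Check the installed CPU builds against the versions noted above.
import torch
import vllm

print("torch:", torch.__version__)  # required: 2.3.0+cpu
print("vllm:", vllm.__version__)    # validated: 0.4.2+cpu
```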
## Usage Example

ITREX provides a script that demonstrates vLLM inference acceleration. Run it with the following command (here `numactl` binds the process to NUMA node 0 and cores 0-55; adjust for your machine):

```bash
numactl -m 0 -C 0-55 python vllm_acceleration_example.py --model_path=/home/model/chatglm2-6b --prompt=你好
```
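For programmatic use, the heart of that script is just a few lines. The following is a minimal sketch distilled from `vllm_acceleration_example.py`, shown later in this commit; the model path is the same placeholder used in the command above:

```python
from vllm import SamplingParams
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

# use_vllm=True routes generation through the vLLM CPU backend, with QBits
# acceleration applied by ITREX, exactly as the example script does.
model = AutoModelForCausalLM.from_pretrained("/home/model/chatglm2-6b", use_vllm=True)
outputs = model.generate("你好", SamplingParams(max_tokens=32))
print(outputs)
```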
## Supported and Validated Models

In principle, all models listed in the [vLLM Supported Models](https://docs.vllm.ai/en/latest/models/supported_models.html) can be accelerated.

We have validated the majority of existing models using vLLM==0.4.2+cpu:

* [THUDM/chatglm2-6b](https://hf-mirror.com/THUDM/chatglm2-6b)
* [meta-llama/Llama-2-7b-chat-hf](https://hf-mirror.com/meta-llama/Llama-2-7b-chat-hf)
* [baichuan-inc/Baichuan2-7B-Chat](https://hf-mirror.com/baichuan-inc/Baichuan2-7B-Chat)
* [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b)
* [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
* [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
* [microsoft/phi-2](https://huggingface.co/microsoft/phi-2)
* [Qwen/CodeQwen1.5-7B-Chat](https://huggingface.co/Qwen/CodeQwen1.5-7B-Chat)

If you encounter any problems, please let us know.

examples/vllm/requirement.txt

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
accelerate
datasets
peft

examples/vllm/vllm_acceleration_example.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time
import os
from vllm import LLM, SamplingParams
from typing import List, Optional
from intel_extension_for_transformers.transformers import AutoModelForCausalLM, RtnConfig
from transformers import AutoTokenizer


def main(args_in: Optional[List[str]] = None) -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, help="Model name: String", required=True)
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        help="Prompt to start generation with: String (default: empty)",
        default="Once upon a time",
    )
    parser.add_argument("--benchmark", action="store_true")
    parser.add_argument("--use_neural_speed", action="store_true")
    args = parser.parse_args(args_in)
    print(args)

    if args.benchmark:
        if args.use_neural_speed:
            os.environ["NEURAL_SPEED_VERBOSE"] = "1"
            woq_config = RtnConfig(bits=4, weight_dtype="int4", compute_dtype="int8", scale_dtype="bf16")
            model_with_ns = AutoModelForCausalLM.from_pretrained(args.model_path, quantization_config=woq_config)

            tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
            inputs = tokenizer(args.prompt, return_tensors="pt").input_ids

            T5 = time.time()
            output = model_with_ns.generate(inputs, max_new_tokens=32)
            T6 = time.time()
            print("neural speed output = ", output)

        llm = LLM(model=args.model_path, trust_remote_code=True)
        sampling_params = SamplingParams(max_tokens=32)
        T1 = time.time()
        original_outputs = llm.generate(args.prompt, sampling_params)  # Generate texts from the prompts.
        T2 = time.time()
        vllm_latency = (T2 - T1) * 1000

        model = AutoModelForCausalLM.from_pretrained(args.model_path, use_vllm=True)
        T3 = time.time()
        optimized_output = model.generate(args.prompt, sampling_params)
        T4 = time.time()
        qbits_latency = (T4 - T3) * 1000

        print("original outputs = ", original_outputs)
        print("input_tokens_length = ", len(original_outputs[0].prompt_token_ids))
        print("output_tokens_length = ", len(original_outputs[0].outputs[0].token_ids))

        print("optimized outputs = ", optimized_output)
        print("input_tokens_length = ", len(optimized_output[0].prompt_token_ids))
        print("output_tokens_length = ", len(optimized_output[0].outputs[0].token_ids))

        print('The qbits optimized generate:%.2f ms' % qbits_latency)
        print('The original vLLM generate:%.2f ms' % vllm_latency)

        return

    model = AutoModelForCausalLM.from_pretrained(args.model_path, use_vllm=True)
    output = model.generate(args.prompt)
    print(output)


if __name__ == "__main__":
    main()

intel_extension_for_transformers/transformers/llm/quantization/nn/modules.py

Lines changed: 33 additions & 57 deletions
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import torch
 from ..utils import DTYPE_BITS_MAPPING
 from functools import reduce
@@ -23,19 +24,19 @@
 from peft.tuners.lora import LoraLayer, LoraModel
 from peft.utils.other import transpose
 from intel_extension_for_transformers.transformers.llm.quantization.autograd import (
-    matmul_kbit,
-)
+    matmul_kbit, )
 import intel_extension_for_transformers.qbits as qbits  # pylint: disable=E0611, E0401
 
 
 class DropoutQBits_(torch.autograd.Function):
+
     @staticmethod
     def forward(ctx, input, probability):
         mask = qbits.dropout_fwd(input, probability)
         if any(ctx.needs_input_grad[:1]):
-            ctx.tensors = (mask,)
+            ctx.tensors = (mask, )
         else:
-            ctx.tensors = (None,)
+            ctx.tensors = (None, )
         return input
 
     @staticmethod
@@ -51,6 +52,7 @@ def backward(ctx, grad_output):
 
 
 class DropoutQBits(torch.nn.Module):
+
     def __init__(self, p=0.0):
         super().__init__()
         self.p = p
@@ -63,6 +65,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
 
 
 class ParamsQBits(torch.nn.Parameter):
+
     def __new__(
         cls,
         data=None,
@@ -87,6 +90,7 @@ def __new__(
 
 
 class QuantizedLinearQBits(torch.nn.Linear):
+
     def __init__(
         self,
         input_features,
@@ -156,6 +160,9 @@ def forward(self, x: torch.Tensor):
         shape[-1] = self.out_features
         out = out.view(shape)
 
+        if os.environ.get("backend", None) == "use_vllm":
+            return out, None
+
         return out
 
     def set_fp_weights_bias(self, weight_data, bias=None):
@@ -264,33 +271,24 @@ def quant_weight_w_scale(self, weight, scale, zp, group_size=-1):
         if zp is not None:
             zp = zp.to(device)
         if group_size == -1:
-            return (
-                weight.div_(scale).round_()
-                if zp is None
-                else weight.div_(scale).add_(zp).round_()
-            )
+            return (weight.div_(scale).round_() if zp is None else weight.div_(scale).add_(zp).round_())
         int_weight = torch.zeros(weight.shape).to(device)
         leng = weight.shape[1] // group_size
         tail_flag = False if weight.shape[1] % group_size == 0 else True
         for i in range(leng):
-            int_weight_tmp = weight[:, i * group_size : (i + 1) * group_size].div_(
-                scale[:, i].unsqueeze(1)
-            )
+            int_weight_tmp = weight[:, i * group_size:(i + 1) * group_size].div_(scale[:, i].unsqueeze(1))
             if zp is not None:
                 int_weight_tmp.add_(zp[:, i].unsqueeze(1))
-            int_weight[:, i * group_size : (i + 1) * group_size].copy_(
-                int_weight_tmp.round_()
-            )
+            int_weight[:, i * group_size:(i + 1) * group_size].copy_(int_weight_tmp.round_())
         if tail_flag:
-            int_weight_tmp = weight[:, leng * group_size :].div_(
-                scale[:, -1].unsqueeze(1)
-            )
+            int_weight_tmp = weight[:, leng * group_size:].div_(scale[:, -1].unsqueeze(1))
             if zp is not None:
                 int_weight_tmp.add_(zp[:, -1].unsqueeze(1))
-            int_weight[:, leng * group_size :].copy_(int_weight_tmp.round_())
+            int_weight[:, leng * group_size:].copy_(int_weight_tmp.round_())
         return int_weight
 
     def recover_qparms(self):
+
        def recover_idx(ret_idx, k, blocksize):
            g_idx = torch.zeros(k, dtype=int)
            value_range = (k + blocksize - 1) // blocksize
@@ -328,18 +326,12 @@ def recover_int_weight(g_idx, int_weight):
         else:
             g_idx = None
         weight_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 6)
-        weight_dtype = "".join(
-            chr(ascii_code) for ascii_code in weight_dtype_ascii.tolist()
-        )
+        weight_dtype = "".join(chr(ascii_code) for ascii_code in weight_dtype_ascii.tolist())
         bits = 4 if weight_dtype in ["nf4", "int4_clip", "fp4", "int4_fullrange"] else 8
         compute_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 7)
-        compute_dtype = "".join(
-            chr(ascii_code) for ascii_code in compute_dtype_ascii.tolist()
-        )
+        compute_dtype = "".join(chr(ascii_code) for ascii_code in compute_dtype_ascii.tolist())
         scales_dtype_ascii = qbits.acquire_packed_weight_info(self.weight, 8)
-        scales_dtype = "".join(
-            chr(ascii_code) for ascii_code in scales_dtype_ascii.tolist()
-        )
+        scales_dtype = "".join(chr(ascii_code) for ascii_code in scales_dtype_ascii.tolist())
         if scales_dtype is None:
             assert False, "scales dtype only support fp32."
         scales = qbits.acquire_packed_weight_info(self.weight, 9)
@@ -356,9 +348,7 @@ def recover_int_weight(g_idx, int_weight):
 
         revert_wei = torch.zeros(in_features, out_features, dtype=torch.float)
 
-        qbits.dequantize_packed_weight(
-            self.weight, revert_wei, False, compute_dtype, weight_dtype, scales_dtype
-        )
+        qbits.dequantize_packed_weight(self.weight, revert_wei, False, compute_dtype, weight_dtype, scales_dtype)
 
         int_weight = self.quant_weight_w_scale(
             revert_wei.t(),
@@ -426,9 +416,7 @@ def __init__(
         except:
             qbits_customop_available = False
         if lora_dropout > 0 and qbits_customop_available:
-            self.lora_dropout = torch.nn.ModuleDict(
-                {adapter_name: DropoutQBits(p=lora_dropout)}
-            )
+            self.lora_dropout = torch.nn.ModuleDict({adapter_name: DropoutQBits(p=lora_dropout)})
 
     def merge(self, safe_merge: bool = False) -> None:
         """Merge the active adapter weights into the base weights.
@@ -440,10 +428,8 @@ def merge(self, safe_merge: bool = False) -> None:
                 NaNs. Defaults to `False`.
         """
         if self.merged:
-            print(
-                f"Already following adapters were merged {','.join(self.merged_adapters)}. "
-                f"You are now additionally merging {','.join(self.active_adapters)}."
-            )
+            print(f"Already following adapters were merged {','.join(self.merged_adapters)}. "
+                  f"You are now additionally merging {','.join(self.active_adapters)}.")
         w_dequant = torch.zeros(
             self.out_features,
             self.in_features,
@@ -468,8 +454,7 @@ def merge(self, safe_merge: bool = False) -> None:
 
                 if not torch.isfinite(orig_weights).all():
                     raise ValueError(
-                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
-                    )
+                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken")
 
                 w_data = orig_weights
             else:
@@ -541,13 +526,10 @@ def unmerge(self) -> None:
             )
 
     def get_delta_weight(self, adapter) -> torch.Tensor:
-        return (
-            transpose(
-                self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
-                False,
-            )
-            * self.scaling[adapter]
-        )
+        return (transpose(
+            self.lora_B[adapter].weight @ self.lora_A[adapter].weight,
+            False,
+        ) * self.scaling[adapter])
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.disable_adapters:
@@ -602,24 +584,18 @@ def _create_new_module(self, lora_config, adapter_name, target, **kwargs):
             bias = kwargs.pop("bias", False)
             in_features, out_features = target.in_features, target.out_features
             if kwargs["fan_in_fan_out"]:
-                print(
-                    "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
-                    "Setting fan_in_fan_out to False."
-                )
+                print("fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
+                      "Setting fan_in_fan_out to False.")
                 kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
             kwargs["compute_dtype"] = target.compute_dtype
             kwargs["compress_statistics"] = target.compress_statistics
             kwargs["weight_dtype"] = target.weight_dtype
             kwargs["scale_dtype"] = target.scale_dtype
             kwargs["blocksize"] = target.blocksize
             kwargs["scheme"] = target.scheme
-            new_module = QuantizedLoraLinearQBits(
-                adapter_name, in_features, out_features, bias=bias, **kwargs
-            )
+            new_module = QuantizedLoraLinearQBits(adapter_name, in_features, out_features, bias=bias, **kwargs)
         else:
-            new_module = QBitsLoraModel._create_new_module_(
-                lora_config, adapter_name, target, **kwargs
-            )
+            new_module = QBitsLoraModel._create_new_module_(lora_config, adapter_name, target, **kwargs)
         return new_module
 
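Apart from the added `import os`, the only functional change in this file is the new environment-variable gate in `QuantizedLinearQBits.forward`; the remaining hunks are cosmetic reformatting. Below is a toy illustration of that gate, a sketch using a hypothetical `GatedLinear` class in place of the QBits layer; the `(out, None)` return is presumably meant to mirror the `(output, bias)` pairs that vLLM's linear layers hand back:

```python
import os

import torch


class GatedLinear(torch.nn.Linear):
    """Hypothetical stand-in for QuantizedLinearQBits, showing only the new gate."""

    def forward(self, x: torch.Tensor):
        out = super().forward(x)
        # When the environment variable backend is set to "use_vllm" (as the
        # vLLM integration path is expected to do), return a (output, bias)
        # style tuple instead of a bare tensor.
        if os.environ.get("backend", None) == "use_vllm":
            return out, None
        return out


layer = GatedLinear(4, 4)
x = torch.randn(2, 4)
print(type(layer(x)))   # <class 'torch.Tensor'>
os.environ["backend"] = "use_vllm"
print(type(layer(x)))   # <class 'tuple'>
```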
