This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit 81d4c56

Support weight-only kernel with IPEX for intel GPU (#1153)
1 parent 957785d commit 81d4c56

20 files changed: +1023 −218 lines changed


.github/workflows/script/formatScan/nlp_dict.txt

Lines changed: 4 additions & 1 deletion
@@ -2428,6 +2428,9 @@ aj
 Życzyński
 Zyczynski
 CES
+DPCPP
+QLLM
+Qwen
 Chroma
 HuggingFacePipeline
 Langchain
@@ -2438,4 +2441,4 @@ VectorStoreRetriever
 langchain
 retrievalQA
 vectorstore
-vectorstores
+vectorstores

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+tests/*.pt
 /intel_extension_for_transformers/llm/runtime/graph/*
 !/intel_extension_for_transformers/llm/runtime/graph/*.*
 !/intel_extension_for_transformers/llm/runtime/graph/*/

docs/weightonlyquant.md

Lines changed: 91 additions & 5 deletions
@@ -5,7 +5,9 @@ Weight Only Quantization (WOQ)
 
 2. [Supported Framework Model Matrix](#supported-framework-model-matrix)
 
-3. [Examples](#examples)
+3. [Examples For CPU](#examples-for-cpu)
+
+4. [Examples For GPU](#examples-for-gpu)
 
 ## Introduction
 
@@ -17,7 +19,12 @@ As large language models (LLMs) become more prevalent, there is a growing need f
 | RTN | ✔ | ✔ |
 | AWQ | ✔ | stay tuned |
 | TEQ | ✔ | stay tuned |
-| GPTQ | stay tuned | ✔ |
+| GPTQ | ✔ | ✔ |
+
+| Support Device | RTN | AWQ | TEQ | GPTQ |
+|:--------------:|:----------:|:----------:|:----------:|:----:|
+| CPU | ✔ | ✔ | ✔ | ✔ |
+| GPU | ✔ | stay tuned | stay tuned | stay tuned |
 > **RTN:** A very intuitive quantization method. It does not require an additional dataset and is very fast. Generally speaking, RTN converts the weight into a uniformly distributed integer data type, although some algorithms, such as QLoRA, propose a non-uniform NF4 data type and prove its theoretical optimality.
 
 > **GPTQ:** A new one-shot weight quantization method based on approximate second-order information that is both highly accurate and highly efficient. The weights of each column are updated based on the fixed-scale pseudo-quantization error and the inverse of the Hessian matrix calculated from the activations. The updated columns sharing the same scale may generate a new max/min value, so the scale needs to be saved for restoration.
@@ -27,7 +34,7 @@ As large language models (LLMs) become more prevalent, there is a growing need f
 > **TEQ:** A trainable equivalent transformation that preserves the FP32 precision in weight-only quantization. It is inspired by AWQ while providing a new solution to search for the optimal per-channel scaling factor between activations and weights.
 
 
-## Examples
+## Examples For CPU
 
 Our motivation is to improve CPU support for weight-only quantization, since `bitsandbytes` only supports CUDA GPU devices. We have extended the `from_pretrained` function so that `quantization_config` can accept [`WeightOnlyQuantConfig`](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/transformers/utils/quantization_config.py#L28) to implement conversion on the CPU. In addition to PyTorch, we also provide an LLM Runtime backend implemented in C++. Here are the example codes.
 
@@ -133,6 +140,85 @@ loaded_model = AutoModelForCausalLM.from_pretrained(saved_dir)
 | Inference Framework | Load GPT-Q model from HuggingFace | Load the saved low-precision model from ITREX |
 |:--------------:|:----------:|:----------:|
 | LLM Runtime (use_llm_runtime=True) | ✔ | ✔ |
-| PyTorch (use_llm_runtime=False) | stay tuned | ✔ |
+| PyTorch (use_llm_runtime=False) | ✔ | ✔ |
+
+> Note: For LLM Runtime model loading usage, please refer to [graph readme](../intel_extension_for_transformers/llm/runtime/graph/README.md#2-run-llm-with-transformer-based-api)
+
+## Examples For GPU
+Intel Extension for Transformers implements weight-only quantization for Intel GPUs (PVC and ARC) with [Intel-extension-for-pytorch](https://github.com/intel/intel-extension-for-pytorch). Currently, the Linear op kernel for weight-only quantization is implemented in the Intel-extension-for-pytorch branch "dev/QLLM".
+We support experimental weight-only quantization (WOQ) inference on Intel GPUs (PVC and ARC) by replacing the Linear op in PyTorch. Validated models: Qwen-7B, GPT-J-6B.
+Here are the example codes.
+
+#### Prepare Dependency Packages
+1. Install the oneAPI Package
+The weight-only quantization ops only exist in the "dev/QLLM" branch of intel-extension-for-pytorch and need to be compiled with the oneAPI DPC++ compiler. Please follow [the link](https://www.intel.com/content/www/us/en/developer/articles/guide/installation-guide-for-oneapi-toolkits.html) to install oneAPI to the "/opt/intel" folder.
 
-> Note: Only supports CPU device for now. For LLM runtime model loading usage, please refer to [graph readme](../intel_extension_for_transformers/llm/runtime/graph/README.md#2-run-llm-with-transformer-based-api)
+2. Build and Install PyTorch and Intel-extension-for-pytorch
+```
+python -m pip install torch==2.1.0a0 -f https://developer.intel.com/ipex-whl-stable-xpu
+
+source /opt/intel/oneapi/setvars.sh
+
+git clone https://github.com/intel-innersource/frameworks.ai.pytorch.ipex-gpu.git ipex-gpu
+cd ipex-gpu
+git checkout -b dev/QLLM origin/dev/QLLM
+git submodule update --init --recursive
+
+pip install -r requirements.txt
+python setup.py install
+```
+
+3. Install Intel-extension-for-transformers and Neural-compressor
+```
+pip install neural-compressor
+pip install intel-extension-for-transformers
+```
+
+4. Run The Example
+```
+import torch
+import intel_extension_for_pytorch as ipex
+from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM
+from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig
+from transformers import AutoTokenizer
+
+device_map = "xpu"
+model_name = "hf-internal-testing/tiny-random-gptj"
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+prompt = "how to test the code?"
+input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device_map)
+
+config = WeightOnlyQuantConfig(weight_dtype="int4_fullrange",
+                               algorithm="RTN",
+                               group_size=32,
+                               compute_dtype="fp16",
+                               scale_dtype="fp16")
+qmodel = AutoModelForCausalLM.from_pretrained(model_name, use_llm_runtime=False,
+                                              device_map=device_map, quantization_config=config,
+                                              trust_remote_code=True, torch_dtype=torch.float16)
+
+# Save the quantized model; this must be done before ipex.optimize_transformers is called.
+qmodel.save_pretrained("saved_dir")
+
+# Optimize the model with IPEX to improve performance.
+qmodel = ipex.optimize_transformers(qmodel, inplace=True, dtype=torch.float16, woq=True, device=device_map)
+
+generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=1)
+output = qmodel.generate(
+    input_ids, max_new_tokens=32, **generate_kwargs
+)
+gen_text = tokenizer.batch_decode(
+    output, skip_special_tokens=True
+)
+
+# Load the saved quantized model.
+loaded_model = AutoModelForCausalLM.from_pretrained(
+    "saved_dir", trust_remote_code=True, device_map=device_map
+)
+
+# Before running the loaded model, you can call ipex.optimize_transformers.
+loaded_model = ipex.optimize_transformers(loaded_model, inplace=True, dtype=torch.float16, woq=True, device=device_map)
+
+```
+> Note:
+> * Saving the quantized model must be done before the optimize_transformers function is called.
+> * The optimize_transformers function is designed to optimize transformer-based models within frontend Python modules, with a particular focus on Large Language Models (LLMs). It provides both model-wise and content-generation-wise optimizations. For details of `optimize_transformers`, please refer to [the link](https://github.com/intel/intel-extension-for-pytorch/blob/xpu-main/docs/tutorials/llm/llm_optimize_transformers.md).
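
The RTN description in the updated doc above ("convert the weight into a uniformly distributed integer data type") can be made concrete with a small sketch. The following is a minimal, self-contained illustration of group-wise symmetric round-to-nearest INT4 quantization in PyTorch; the function names and the choice of symmetric range are assumptions for illustration, not the ITREX/IPEX kernel implementation.

```
import torch

def rtn_quantize_int4(weight: torch.Tensor, group_size: int = 32):
    """Illustrative group-wise symmetric round-to-nearest (RTN) INT4 quantization.

    weight: 2-D tensor [out_features, in_features]; in_features must be divisible by group_size.
    Returns int4-valued codes (stored in int8) and per-group scales.
    """
    out_features, in_features = weight.shape
    w = weight.reshape(out_features, in_features // group_size, group_size)
    # One symmetric scale per group: map the largest magnitude in the group to 7.
    scales = w.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 7.0
    q = torch.clamp(torch.round(w / scales), -8, 7).to(torch.int8)
    return q.reshape(out_features, in_features), scales.squeeze(-1)

def rtn_dequantize(q: torch.Tensor, scales: torch.Tensor, group_size: int = 32):
    out_features, in_features = q.shape
    qg = q.reshape(out_features, in_features // group_size, group_size).to(scales.dtype)
    return (qg * scales.unsqueeze(-1)).reshape(out_features, in_features)

# Round-trip example: quantize and dequantize a random weight matrix.
w = torch.randn(8, 64)
q, s = rtn_quantize_int4(w)
w_hat = rtn_dequantize(q, s)
print("max abs error:", (w - w_hat).abs().max().item())
```

The `group_size=32` default mirrors the `group_size=32` used in the GPU example above; the doc's `weight_dtype="int4_fullrange"` presumably refers to using the full [-8, 7] INT4 range, which the clamp above also allows.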

env_gpu.sh

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+# Set up the environment for Intel oneAPI DPC++/C++ Compiler
+# ONEAPI_INSTALL_PATH below assumes you installed to the default folder /opt/intel/oneapi
+# If you customized the installation folder, please update ONEAPI_INSTALL_PATH to your custom folder
+ONEAPI_INSTALL_PATH=/opt/intel/oneapi
+source ${ONEAPI_INSTALL_PATH}/setvars.sh

examples/huggingface/pytorch/text-generation/quantization/run_generation.py

Lines changed: 2 additions & 2 deletions
@@ -272,7 +272,7 @@
         )
     elif args.woq:
         if args.woq_algo == "GPTQ":
-            gptq_recipes = {
+            algorithm_args = {
                 "act_order": args.gptq_actorder,
                 "percdamp": args.gptq_percdamp,
                 "block_size": args.gptq_block_size,
@@ -288,7 +288,7 @@
                 group_size=args.gptq_block_size,
                 algorithm=args.woq_algo,
                 tokenizer=tokenizer,
-                gptq_recipes=gptq_recipes,
+                algorithm_args=algorithm_args,
             )
         else:
             quantization_config = WeightOnlyQuantConfig(
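
As a quick orientation for the rename above: the GPTQ-specific options are now passed through the generic `algorithm_args` parameter of `WeightOnlyQuantConfig` instead of a `gptq_recipes` keyword. A condensed, hypothetical sketch of the resulting configuration (the literal values stand in for the script's `args.*` flags, and the `weight_dtype` choice is illustrative):

```
from intel_extension_for_transformers.transformers import WeightOnlyQuantConfig

# Stand-ins for args.gptq_actorder, args.gptq_percdamp, args.gptq_block_size.
algorithm_args = {
    "act_order": False,
    "percdamp": 0.01,
    "block_size": 128,
}
quantization_config = WeightOnlyQuantConfig(
    weight_dtype="int4_fullrange",   # illustrative; run_generation.py derives this from its flags
    group_size=128,                  # args.gptq_block_size in the script
    algorithm="GPTQ",                # args.woq_algo
    algorithm_args=algorithm_args,   # renamed from gptq_recipes in this commit
)
# The script also passes tokenizer=tokenizer, presumably for preparing GPTQ calibration data.
```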
Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
+import argparse
+import re
+import time
+import json
+import torch
+from transformers import AutoConfig, AutoTokenizer
+from transformers.generation import GenerationConfig
+import intel_extension_for_pytorch as ipex
+from intel_extension_for_transformers.transformers import AutoModelForCausalLM, WeightOnlyQuantConfig
+from intel_extension_for_transformers.llm.quantization.utils import convert_dtype_str2torch
+from transformers.utils import check_min_version
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--model", nargs="?", default="Qwen/Qwen-7B-Chat", const="Qwen/Qwen-7B-Chat"
+)
+parser.add_argument("--revision", default=None, type=str)
+parser.add_argument("--trust_remote_code", default=True)
+parser.add_argument(
+    "--dataset", nargs="?", default="NeelNanda/pile-10k", const="NeelNanda/pile-10k"
+)
+parser.add_argument(
+    "--max-new-tokens", default=32, type=int, help="output max new tokens"
+)
+parser.add_argument(
+    "--num_beams", default=1, type=int, help="number of beams"
+)
+parser.add_argument("--output_dir", nargs="?", default="./saved_results")
+parser.add_argument("--int8", action="store_true")
+parser.add_argument(
+    "--int8_bf16_mixed",
+    action="store_true",
+    help="by default it is int8-fp32 mixed, to enable int8 mixed amp bf16 (work on platforms like SPR)",
+)
+parser.add_argument("--peft_model_id", type=str, default=None, help="model_name_or_path of peft model")
+# ============Benchmark configs==============
+parser.add_argument("--benchmark", action="store_true")
+parser.add_argument("--do_profiling", action="store_true")
+parser.add_argument("--profile_token_latency", action="store_true")
+parser.add_argument("--iters", default=10, type=int, help="num iter")
+parser.add_argument("--num_warmup", default=3, type=int, help="num warmup")
+# ============Accuracy configs==============
+parser.add_argument("--accuracy", action="store_true")
+parser.add_argument("--batch_size", default=1, type=int,
+                    help="batch size num.")
+parser.add_argument("--save_accuracy_path", default=None,
+                    help="Save accuracy results path.")
+parser.add_argument("--tasks", nargs='+', default=["lambada_openai"], type=str,
+                    help="tasks list for accuracy validation")
+# ============WeightOnlyQuant configs===============
+parser.add_argument("--woq", action="store_true")
+parser.add_argument("--woq_algo", default="RTN", choices=['RTN'],
+                    help="Weight-only parameter.")
+parser.add_argument("--woq_dtype", type=str, default="int4_fullrange",
+                    choices=["int4_fullrange"])
+parser.add_argument("--woq_group_size", type=int, default=32)
+parser.add_argument("--woq_scheme", default="sym")
+parser.add_argument("--woq_enable_mse_search", action="store_true")
+parser.add_argument("--device", default="xpu")
+parser.add_argument("--compute_dtype", default="fp16")
+# ============BitsAndBytes configs==============
+parser.add_argument("--bitsandbytes", action="store_true")
+parser.add_argument("--load_in_4bit", type=bool, default=False)
+parser.add_argument("--load_in_8bit", type=bool, default=False)
+# =======================================
+args = parser.parse_args()
+torch_dtype = convert_dtype_str2torch(args.compute_dtype)
+
+# transformers version >= 4.32.0 contains the mpt modeling definition.
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py
+check_min_version("4.31.0")
+
+# get model config
+config = AutoConfig.from_pretrained(
+    args.model,
+    use_cache=True,  # to use kv cache.
+    trust_remote_code=args.trust_remote_code,
+    revision=args.revision,
+)
+generation_config = GenerationConfig.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
+generation_config.do_sample = False
+user_model = None
+
+# tokenizer
+if config.model_type == "llama":
+    from transformers import LlamaTokenizer
+    tokenizer = LlamaTokenizer.from_pretrained(args.model)
+else:
+    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=args.trust_remote_code)
+
+quantization_config = None
+if args.woq:
+    quantization_config = WeightOnlyQuantConfig(
+        compute_dtype=args.compute_dtype, weight_dtype=args.woq_dtype,
+        group_size=args.woq_group_size, scale_dtype=args.compute_dtype
+    )  # default is A16W4G16
+
+# get model
+if quantization_config is not None:
+    user_model = AutoModelForCausalLM.from_pretrained(args.model,
+                                                      device_map=args.device,
+                                                      quantization_config=quantization_config,
+                                                      trust_remote_code=args.trust_remote_code,
+                                                      fp16=True,
+                                                      use_llm_runtime=False
+                                                      )
+elif args.load_in_4bit or args.load_in_8bit:
+    # CPU device usage is provided by intel-extension-for-transformers.
+    user_model = AutoModelForCausalLM.from_pretrained(args.model,
+                                                      device_map=args.device,
+                                                      load_in_4bit=args.load_in_4bit,
+                                                      load_in_8bit=args.load_in_8bit,
+                                                      use_llm_runtime=False
+                                                      )
+if user_model is not None:
+    user_model.save_pretrained(args.output_dir)
+    tokenizer.save_pretrained(args.output_dir)
+
+if args.benchmark:
+    # Chinese-language benchmark prompt.
+    prompt = "它完成了，并提交了。你可以在Android和网络上玩美味生存。在网络上玩是有效的，但你必须模拟多次触摸才能移动桌子."
+
+    input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)
+    print("---- Prompt size:", input_size)
+
+    user_model = AutoModelForCausalLM.from_pretrained(
+        args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
+        if user_model is None else user_model
+    user_model = ipex.optimize_transformers(
+        user_model.eval(), device=args.device, inplace=True, woq=True, dtype=torch_dtype)
+    # start
+    num_iter = args.iters
+    num_warmup = args.num_warmup
+    prompt = [prompt] * args.batch_size
+    amp_enabled = True
+    amp_dtype = torch_dtype
+
+    generate_kwargs = dict(do_sample=False, temperature=0.9, num_beams=args.num_beams)
+    if args.profile_token_latency:
+        generate_kwargs["token_latency"] = True
+
+    total_time = 0.0
+    total_list = []
+    with torch.inference_mode(), torch.no_grad(), torch.autocast(
+        device_type=args.device,
+        enabled=amp_enabled,
+        dtype=amp_dtype if amp_enabled else None,
+    ):
+        for i in range(num_iter + num_warmup):
+            with torch.autograd.profiler_legacy.profile(enabled=args.do_profiling, use_xpu=(args.device=="xpu"), record_shapes=False) as prof:
+                input_ids = tokenizer(
+                    prompt, return_tensors="pt").input_ids.to(args.device)
+                tic = time.time()
+                output = user_model.generate(
+                    input_ids, max_new_tokens=int(args.max_new_tokens), **generate_kwargs
+                )
+                toc = time.time()
+                gen_ids = output[0] if args.profile_token_latency else output
+                gen_text = tokenizer.batch_decode(
+                    gen_ids, skip_special_tokens=True)
+                if args.device == "xpu":
+                    torch.xpu.synchronize()
+            if args.do_profiling and i >= num_warmup and (i == num_warmup or i == num_iter + num_warmup - 1):
+                print(f"Save pt for iter {i}")
+                torch.save(prof.key_averages().table(
+                    sort_by="self_xpu_time_total"), f"./profile_{i}.pt")
+                # torch.save(prof.table(sort_by="id", row_limit=-1),
+                #            './profile_id.pt')
+                # torch.save(prof.key_averages(
+                #     group_by_input_shape=True).table(), "./profile_detail.pt")
+                prof.export_chrome_trace(f"./trace_{i}.json")
+            input_tokens_lengths = [x.shape[0] for x in input_ids]
+            output_tokens_lengths = [x.shape[0] for x in gen_ids]
+            total_new_tokens = [
+                o - i if user_model.config.model_type != "t5" else o
+                for i, o in zip(input_tokens_lengths, output_tokens_lengths)
+            ]
+            print(gen_text, total_new_tokens, flush=True)
+            print("Iteration: %d, Time: %.6f sec" % (i, toc - tic), flush=True)
+            if i >= num_warmup:
+                total_time += toc - tic
+                if args.profile_token_latency:
+                    total_list.append(output[1])
+
+    print("\n", "-" * 10, "Summary:", "-" * 10)
+    # total_time accumulates num_iter timed iterations (warmup iterations are excluded above).
+    latency = total_time / num_iter
+    print("Inference latency: %.5f sec." % latency)
+    throughput = (args.max_new_tokens + input_size) / latency
+    print("Average throughput: {} tokens/sec".format(throughput))
+
+    if args.profile_token_latency:
+        import numpy as np
+        from itertools import chain
+
+        first_latency = np.mean([x[0] for x in total_list])
+        average_2n = list(chain(*[x[1:] for x in total_list]))
+        average_2n.sort()
+        average_2n_latency = np.mean(average_2n)
+        print("First token average latency: %.5f sec." % first_latency)
+        print("Average 2... latency: %.5f sec." % average_2n_latency)
+        print(total_list)
+
+
+if args.accuracy:
+    from intel_extension_for_transformers.llm.evaluation.lm_eval import evaluate
+    user_model = AutoModelForCausalLM.from_pretrained(
+        args.model, trust_remote_code=args.trust_remote_code, device_map=args.device, torch_dtype=torch_dtype) \
+        if user_model is None else user_model
+    user_model = ipex.optimize_transformers(
+        user_model.eval(), device=args.device, inplace=True, woq=True, dtype=torch_dtype)
+    results = evaluate(
+        model="hf-causal",
+        model_args='pretrained=' + args.model + ',tokenizer=' + args.model + ',dtype=float32',
+        user_model=user_model,
+        batch_size=args.batch_size,
+        tasks=args.tasks,
+        device=args.device
+    )
+    dumped = json.dumps(results, indent=2)
+    if args.save_accuracy_path:
+        with open(args.save_accuracy_path, "w") as f:
+            f.write(dumped)
+    for task_name in args.tasks:
+        if task_name == "wikitext":
+            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["word_perplexity"]))
+        else:
+            print("Accuracy for %s is: %s" % (task_name, results["results"][task_name]["acc"]))
