
Commit 25550b1

GPTQ updates
Summary:

1) Reorganized GPTQ
   a) got rid of the old GPTQ and renamed GPTQ_MT to GPTQ
   b) moved the new GPTQ to prototype
   c) moved the quantized linear modules in GPTQ.py to linear_quant_modules.py
2) Removed the dependence on lm_eval for input_recorder
   a) created a new input recorder that doesn't depend on lm_eval
   b) made the lm_eval input recorder depend on the new generic input_recorder
   c) made TransformerEvalWrapper the base class and made LMEvalInputRecorder inherit from it instead of vice versa
   d) updated APIs generally to work with the new input recorder
3) Reorganized GPTQ tests
   a) moved tests from test_quant_api.py to test_gptq.py
   b) added a new test that can run in CI and doesn't depend on lm_eval or llama weights
   c) got rid of test_gptq_mt.py
4) Added new documentation for lm_eval
5) GPTQ improvements
   a) reimplemented faster quant
   b) tested compilation of the hessian calculation and parts of faster quant; they were generally slower
   c) moved helper functions out of the class; they're largely generic and this is less cluttered
   d) made the duplication checking and copying faster where possible
   e) fixed some bugs caused by this code not being in CI while the int4wo tensor subclass changed

Test Plan:

1) `python test_gptq.py` (note: the skipped test test_gptq_quantizer_int4_weight_only was also run)
2) Verified that all activations match between the old GPTQ and the current GPTQ
3)
```shell
export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder

export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-gptq-64 --calibration_limit 10

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-64
python eval.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization int4wo-gptq-64 --calibration_limit 10
```
See README.md for results; they show that GPTQ is working.

Reviewers:

Subscribers:

Tasks:

Tags:
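For context, a minimal sketch (not part of this commit's diff) of the lm_eval-free calibration flow the new generic input recorder enables, using the APIs exercised in test_gptq.py below; the toy model size, sequence lengths, and token shapes here are arbitrary choices for illustration:

```python
import torch
from torchao._models.llama.model import ModelArgs, Transformer, prepare_inputs_for_model
from torchao.prototype.GPTQ.GPTQ import Int4WeightOnlyGPTQQuantizer, MultiTensorInputRecorder

torch.set_default_dtype(torch.bfloat16)
with torch.device("cuda"):
    # A small toy model; any module the quantizer supports works the same way.
    model = Transformer(ModelArgs(n_layer=2))
    model.setup_caches(max_batch_size=1, max_seq_length=128)

    # Record calibration inputs; each call stores one set of forward args as MultiTensors.
    recorder = MultiTensorInputRecorder()
    for _ in range(10):
        tokens = torch.randint(1, 10000, (1, 64)).to(torch.int32)
        recorder(*prepare_inputs_for_model(tokens))

    # Replay the recorded inputs through GPTQ (int4 weight-only), as in test_gptq.py.
    quantizer = Int4WeightOnlyGPTQQuantizer()
    quantizer.quantize(model, recorder.get_recorded_inputs())
```

Because the recorder just stacks each call's arguments into MultiTensors, GPTQ can calibrate on arbitrary recorded inputs without pulling in lm_eval; the lm_eval path (LMEvalInputRecorder) now builds on this same mechanism.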
1 parent 7854249 commit 25550b1

File tree

15 files changed: +1529 lines added, -2532 lines removed


test/quantization/test_gptq.py

Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
import unittest
from pathlib import Path

import torch
from torch.testing._internal.common_utils import TestCase

from torchao._models.llama.model import Transformer, TransformerBlock, prepare_inputs_for_model, ModelArgs
from torchao._models.llama.tokenizer import get_tokenizer
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_4,
)
from torchao.quantization import quantize_, Int4WeightOnlyConfig
from torchao.quantization.utils import compute_error

torch.manual_seed(0)

class TestGPTQ(TestCase):
    @unittest.skip("skipping until we get checkpoints for gpt-fast")
    def test_gptq_quantizer_int4_weight_only(self):
        from torchao._models._eval import (
            LMEvalInputRecorder,
            TransformerEvalWrapper,
        )
        from torchao.prototype.GPTQ.GPTQ import Int4WeightOnlyGPTQQuantizer

        precision = torch.bfloat16
        device = "cuda"
        checkpoint_path = Path("../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
        model = Transformer.from_name(checkpoint_path.parent.name)
        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
        model.load_state_dict(checkpoint, assign=True)
        model = model.to(dtype=precision, device="cpu")
        model.eval()

        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
        assert tokenizer_path.is_file(), tokenizer_path
        tokenizer = get_tokenizer(  # pyre-ignore[28]
            tokenizer_path,
            "Llama-2-7b-chat-hf",
        )
        groupsize = 64
        blocksize = 128
        percdamp = 0.01
        calibration_tasks = ["wikitext"]
        calibration_limit = 1
        calibration_seq_length = 100
        input_prep_func = prepare_inputs_for_model
        pad_calibration_inputs = False
        inputs = (
            LMEvalInputRecorder(
                tokenizer,
                calibration_seq_length,
                input_prep_func,
                model.config.vocab_size,
                pad_calibration_inputs,
                device="cpu",
            )
            .record_inputs(
                calibration_tasks,
                calibration_limit,
            )
            .get_inputs()
        )

        quantizer = Int4WeightOnlyGPTQQuantizer(
            groupsize,
            blocksize,
            percdamp,
        )
        model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)

        model = quantizer.quantize(model, *inputs).cuda()

        model.reset_caches()
        with torch.device("cuda"):
            model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size)

        limit = 1
        result = TransformerEvalWrapper(
            model.cuda(),
            tokenizer,
            model.config.block_size,
            prepare_inputs_for_model,
            device,
        ).run_eval(
            ["wikitext"],
            limit,
        )

        assert result["results"]["wikitext"]["word_perplexity,none"] < 7.77, (
            f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
        )


class TestMultiTensorFlow(TestCase):
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_multitensor_add_tensors(self):
        from torchao.prototype.GPTQ.GPTQ import MultiTensor

        tensor1 = torch.randn(3, 3)
        tensor2 = torch.randn(3, 3)
        mt = MultiTensor(tensor1)
        mt.add_tensors(tensor2)
        self.assertEqual(mt.count, 2)
        self.assertTrue(torch.equal(mt.values[0], tensor1))
        self.assertTrue(torch.equal(mt.values[1], tensor2))

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_multitensor_pad_unpad(self):
        from torchao.prototype.GPTQ.GPTQ import MultiTensor

        tensor1 = torch.randn(3, 3)
        mt = MultiTensor(tensor1)
        mt.pad_to_length(3)
        self.assertEqual(mt.count, 3)
        mt.unpad()
        self.assertEqual(mt.count, 1)

    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_multitensor_inplace_operation(self):
        from torchao.prototype.GPTQ.GPTQ import MultiTensor

        tensor1 = torch.ones(3, 3)
        mt = MultiTensor(tensor1)
        mt += 1  # In-place addition
        self.assertTrue(torch.equal(mt.values[0], torch.full((3, 3), 2)))

class TestMultiTensorInputRecorder(TestCase):
    def test_multitensor_input_recorder(self):
        from torchao.prototype.GPTQ.GPTQ import MultiTensorInputRecorder, MultiTensor

        input_recorder = MultiTensorInputRecorder()
        x = torch.randn(3, 3)
        in1 = ([1], torch.randn(3, 3), (1, "dog", torch.randn(3, 3)), torch.float)
        in2 = ([1], torch.randn(3, 3), (1, "dog", torch.randn(3, 3)), torch.float)

        input_recorder(*in1)
        input_recorder(*in2)

        MT_input = input_recorder.get_recorded_inputs()

        self.assertEqual(MT_input[0], [1])
        self.assertTrue(isinstance(MT_input[1], MultiTensor))
        self.assertTrue(isinstance(MT_input[2], tuple))
        self.assertEqual(MT_input[2][0], 1)
        self.assertEqual(MT_input[2][1], "dog")
        self.assertTrue(isinstance(MT_input[2][2], MultiTensor))
        self.assertEqual(MT_input[3], torch.float)


    def test_gptq_with_input_recorder(self):
        from torchao.prototype.GPTQ.GPTQ import MultiTensorInputRecorder, Int4WeightOnlyGPTQQuantizer
        torch.set_default_dtype(torch.bfloat16)

        config = ModelArgs(n_layer=2)

        with torch.device("cuda"):
            model = Transformer(config)
            model.setup_caches(max_batch_size=2, max_seq_length=100)
            idx = torch.randint(1, 10000, (10, 2, 50)).to(torch.int32)
            test_input = prepare_inputs_for_model(idx[0])
            import copy
            model2 = copy.deepcopy(model)
            out = model(*test_input)
            quantize_(model2, Int4WeightOnlyConfig())

            outq = model2(*test_input)
            del model2

            input_recorder = MultiTensorInputRecorder()
            for i in range(10):
                input = prepare_inputs_for_model(idx[i])
                input_recorder(*input)

            multi_tensor_inputs = input_recorder.get_recorded_inputs()

            quantizer = Int4WeightOnlyGPTQQuantizer()

            quantizer.quantize(model, multi_tensor_inputs)

            outgptq = model(*test_input)

            self.assertGreater(compute_error(outgptq, out), 30)
            self.assertGreater(compute_error(outgptq, out), compute_error(outq, out))

if __name__ == "__main__":
    unittest.main()
