
Commit 93eb8f7

Fix run_compressed tests (#1246)

Summary
- These tests were previously approved, but something weird happened when they were landing and they were reverted to their old cases... not sure how(?!)
- This brings back the appropriate cases written by @horheynm
- This would have caught the lm-eval test issue that Dan is reporting as well

Requires: [fix_compressed_linear](neuralmagic/compressed-tensors#278)
1 parent 2f7c620 commit 93eb8f7
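
For context on the cases this commit restores: the first test class loads the optimized checkpoint in its decompressed form and compares it against the matching checkpoint that was saved uncompressed. Below is a minimal sketch of that pair of loads, mirroring the calls in the diff further down and using the W4A16 stubs from this commit's configs; it is illustrative only and not part of the commit.

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils.quantization_config import CompressedTensorsConfig

# Stubs taken from the W4A16 config changed in this commit
compressed_stub = "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-compressed"
uncompressed_stub = "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-uncompressed"

# Decompressed: HFQuantizer decompresses the compressed checkpoint at load time,
# so forward passes go through plain Linear modules
decompressed_model = AutoModelForCausalLM.from_pretrained(
    compressed_stub,
    torch_dtype="auto",
    device_map="auto",
    quantization_config=CompressedTensorsConfig(run_compressed=False),
)

# Uncompressed: checkpoint was saved with run_compressed=False, nothing to decompress
uncompressed_model = AutoModelForCausalLM.from_pretrained(
    uncompressed_stub,
    torch_dtype=decompressed_model.dtype,
    device_map=decompressed_model.device,
)

# The restored tests assert that generate() outputs from these two models
# match token for token for a handful of sample prompts
tokenizer = AutoTokenizer.from_pretrained(compressed_stub)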

File tree

3 files changed: +137 -89 lines changed

@@ -1,4 +1,4 @@
-cadence: "commit"
+cadence: "nightly"
 test_type: "regression"
 compressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-compressed
 uncompressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-uncompressed
@@ -1,4 +1,4 @@
-cadence: "commit"
+cadence: "nightly"
 test_type: "regression"
 compressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A8-Dynamic-Per-Token-compressed
 uncompressed_model_stub: nm-testing/TinyLlama-1.1B-Chat-v1.0-W8A8-Dynamic-Per-Token-uncompressed
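
The two YAML configs above (now run on the nightly cadence rather than on every commit) are what parameterize the restored test classes. Here is a hedged sketch of that wiring, assuming tests.testing_utils.parse_params returns one dict of YAML keys per config file in the directory; the helper's exact return shape and the ExampleRunCompressedTest name are assumptions made for illustration.

import unittest

from parameterized import parameterized_class

from tests.testing_utils import parse_params

CONFIG_DIR = "tests/llmcompressor/transformers/compression/run_compressed_configs"


# Each YAML file in the directory becomes one parameterized copy of the class,
# with the YAML keys attached as class attributes (assumed behavior of parse_params)
@parameterized_class(parse_params(CONFIG_DIR))
class ExampleRunCompressedTest(unittest.TestCase):
    compressed_model_stub = None
    uncompressed_model_stub = None

    def test_stubs_are_populated(self):
        # Populated per config file, e.g. the W4A16 and W8A8 stubs above
        assert self.compressed_model_stub is not None
        assert self.uncompressed_model_stub is not None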
@@ -1,134 +1,182 @@
-import copy
 import shutil
 import tempfile
 import unittest
 
 import torch
-from compressed_tensors import QUANTIZATION_CONFIG_NAME
-from compressed_tensors.compressors import ModelCompressor
-from compressed_tensors.quantization import QuantizationStatus
+from compressed_tensors.linear.compressed_linear import CompressedLinear
+from compressed_tensors.quantization.utils import iter_named_leaf_modules
 from parameterized import parameterized_class
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers.utils.quantization_config import CompressedTensorsConfig
 
 from tests.testing_utils import parse_params, requires_gpu
 
-CONFIG_DIR = "tests/llmcompressor/transformers/compression/decompression_configs"
+COMPRESSED_LINEAR_CONFIG_DIR = (
+    "tests/llmcompressor/transformers/compression/run_compressed_configs"
+)
 
 
 @requires_gpu
-@parameterized_class(parse_params(CONFIG_DIR))
-class TestDecompression(unittest.TestCase):
+@parameterized_class(parse_params(COMPRESSED_LINEAR_CONFIG_DIR))
+class Test_Decompressed_Linear_Uncompressed_Linear(unittest.TestCase):
     """
-    Check that HFQuantizer decompression is working as expected.
-    Manually decompress a compressed model and compare the generations
+    Uncompressed-Linear-forward decompressed-Linear-forward check
 
-    Decompression:
-    Given a skeleton model and path to the optimized model,
-    write the optimized model's safetensors to the skeleton model and decompress
-    Ex. write weight_scale to the skeleton model and then convert from fp4 to fp16
+    Uncompressed: Optimized model saved as run_compressed=False, no need to decompress
+    Decompressed: Optimized model saved as run_compressed=True, and decompressed using
+        AutoModelForCausalLM decompression
+
+    AutoModelForCausalLM decompression diagram flow https://tinyurl.com/2ynb6wbu
 
     """
 
     compressed_model_stub = None
-    skeleton_model_stub = None
-
-    SAMPLE_INPUTS = [
-        "I love 4-bit quantization because",
-        "What is the capital of France?",
-        "def fibonacci(n):",
-    ]
+    uncompressed_model_stub = None
 
     @classmethod
-    def setUpClass(self):
-        self.test_dir = tempfile.mkdtemp()
-        self.tokenizer = AutoTokenizer.from_pretrained(self.compressed_model_stub)
+    def setUpClass(cls):
+        cls.test_dir = tempfile.mkdtemp()
 
-        # Decompress using HFQuantizer from AutoModelForCausalLM
-        self.decompressed_model_hf_quantizer = AutoModelForCausalLM.from_pretrained(
-            self.compressed_model_stub,
+        quantization_config = CompressedTensorsConfig(run_compressed=False)
+
+        # Decompressed using HFQuantizer
+        # Linear forward
+        cls.decompressed_model = AutoModelForCausalLM.from_pretrained(
+            cls.compressed_model_stub,
             torch_dtype="auto",
             device_map="auto",
-            quantization_config=CompressedTensorsConfig(run_compressed=False),
+            quantization_config=quantization_config,
        )
 
-        # Manually decompress this model
-        self.dense_model = AutoModelForCausalLM.from_pretrained(
-            self.skeleton_model_stub,
-            torch_dtype=self.decompressed_model_hf_quantizer.dtype,
-            device_map=self.decompressed_model_hf_quantizer.device,
+        # Load model as is at the uncompressed state
+        # Linear forward
+        cls.uncompressed_model = AutoModelForCausalLM.from_pretrained(
+            cls.uncompressed_model_stub,
+            torch_dtype=cls.decompressed_model.dtype,
+            device_map=cls.decompressed_model.device,
        )
 
-        # decompression from HFQuantizer should populate weight_scale
-        assert hasattr(
-            self.decompressed_model_hf_quantizer.model.layers[0].self_attn.q_proj,
-            "weight_scale",
-        )
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.compressed_model_stub)
 
-        # dense model should not have weight_scale populated
-        assert not hasattr(
-            self.dense_model.model.layers[0].self_attn.q_proj, "weight_scale"
-        )
+    def test_compressed_matches_decompressed(self):
+        SAMPLE_INPUT = [
+            "I love 4-bit quantization because",
+            "What is the capital of France?",
+            "def fibonacci(n):",
+        ]
+
+        decompressed_device = self.decompressed_model.device
+        uncompressed_device = self.uncompressed_model.device
 
-        config = AutoConfig.from_pretrained(self.compressed_model_stub)
+        # overwrite weights in cpu to cuda
+        self.decompressed_model = self.decompressed_model.to(decompressed_device)
+        self.uncompressed_model = self.uncompressed_model.to(uncompressed_device)
 
-        compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
-        self.compressor = ModelCompressor.from_compression_config(compression_config)
-        self.compressor.quantization_config.quantization_status = (
-            QuantizationStatus.FROZEN
+        inputs = self.tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(
+            decompressed_device
        )
 
-        # use the model_path to load the decompressed weights into dense_model
-        dense_model = copy.deepcopy(self.dense_model)
+        decompressed_output = self.decompressed_model.generate(**inputs, max_length=50)
 
-        # overwrite the weights of the dense model
-        self.compressor.decompress(
-            model_path=self.compressed_model_stub,
-            model=self.dense_model,
-        )
+        inputs = inputs.to(uncompressed_device)
 
-        # self.dense_model should be decompressed
-        assert dense_model is not self.dense_model
+        uncompressed_output = self.uncompressed_model.generate(**inputs, max_length=50)
 
-        self.decompressed_model_manual = self.dense_model
+        for idx in range(len(SAMPLE_INPUT)):
+            assert torch.equal(decompressed_output[idx], uncompressed_output[idx])
 
-        assert hasattr(
-            self.decompressed_model_manual.model.layers[0].self_attn.q_proj,
-            "weight_scale",
-        )
+    @classmethod
+    def tearDownClass(cls):
+        shutil.rmtree(cls.test_dir)
+        del cls.decompressed_model
+        del cls.uncompressed_model
+        torch.cuda.empty_cache()
+
+
+@requires_gpu
+@parameterized_class(parse_params(COMPRESSED_LINEAR_CONFIG_DIR))
+class Test_Compressed_CompressedLinear_Decompressed_Linear(unittest.TestCase):
+    """
+    Compressed-CompressedLinear, Decompressed-Linear check
+
+    Compressed: Optimized model saved as run_compressed=True, no decompression
+    Decompressed: Optimized model saved as run_compressed=True, and decompressed using
+        AutoModelForCausalLM decompression
+
+    All compressed models should have CompressedLinear, which has its custom forward call
+
+    """
+
+    compressed_model_stub = None
+
+    @classmethod
+    def setUpClass(cls):
+        cls.test_dir = tempfile.mkdtemp()
 
-    def test_hf_quantizer_decompress_match_manual_decompress(self):
-        manual_device = self.decompressed_model_manual.device
-        decompressed_model_hf_quantizer = self.decompressed_model_hf_quantizer.device
+        # Should have CompressedLinear modules
+        # Compressed Linear forward
+        cls.compressed_model = AutoModelForCausalLM.from_pretrained(
+            cls.compressed_model_stub,
+            torch_dtype="auto",
+            device_map="auto",
+        )
 
-        self.decompressed_model_manual = self.decompressed_model_manual.to(
-            manual_device
+        # Should just be linear modules
+        # Linear forward
+        quantization_config = CompressedTensorsConfig(run_compressed=False)
+        cls.decompressed_model = AutoModelForCausalLM.from_pretrained(
+            cls.compressed_model_stub,
+            torch_dtype=cls.compressed_model.dtype,
+            device_map=cls.compressed_model.device,
+            quantization_config=quantization_config,
        )
-        self.decompressed_model_hf_quantizer = self.decompressed_model_hf_quantizer.to(
-            decompressed_model_hf_quantizer
+
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.compressed_model_stub)
+
+    def test_compressed_linear_modules_exist(self):
+        compressed_linear_counts = 0
+        for _, submodule in iter_named_leaf_modules(
+            self.compressed_model,
+        ):
+            if isinstance(submodule, CompressedLinear):
+                compressed_linear_counts += 1
+
+        # some linear modules are not compressed - ex. lm_head
+        assert compressed_linear_counts > 0
+
+    def test_compressed_matches_decompressed__hf_quantizer(self):
+        SAMPLE_INPUT = [
+            "I love 4-bit quantization because",
+            "What is the capital of France?",
+            "def fibonacci(n):",
+        ]
+
+        decompressed_device = self.decompressed_model.device
+        compressed_device = self.compressed_model.device
+
+        # overwrite weights in cpu to cuda
+        self.decompressed_model = self.decompressed_model.to(decompressed_device)
+        self.compressed_model = self.compressed_model.to(compressed_device)
+
+        inputs = self.tokenizer(SAMPLE_INPUT, return_tensors="pt", padding=True).to(
+            decompressed_device
        )
 
-        for input in self.SAMPLE_INPUTS:
-            inputs = self.tokenizer(input, return_tensors="pt", padding=True).to(
-                self.decompressed_model_manual.device
-            )
-            inputs = inputs.to(self.decompressed_model_manual.device)
+        decompressed_model_out = self.decompressed_model.generate(
+            **inputs, max_length=50
+        )
 
-            decompressed_model_manual_output = self.decompressed_model_manual.generate(
-                **inputs, max_length=50
-            )
+        inputs = inputs.to(compressed_device)
 
-            decompressed_model_hf_quantizer_out = (
-                self.decompressed_model_hf_quantizer.generate(**inputs, max_length=50)
-            )
+        compressed_model_out = self.compressed_model.generate(**inputs, max_length=50)
 
-            assert torch.equal(
-                decompressed_model_hf_quantizer_out, decompressed_model_manual_output
-            )
+        # Compare outputs for each input
+        for idx in range(len(SAMPLE_INPUT)):
+            torch.equal(compressed_model_out[idx], decompressed_model_out[idx])
 
     @classmethod
-    def tearDownClass(self):
-        shutil.rmtree(self.test_dir)
-        del self.dense_model
-        del self.decompressed_model_hf_quantizer
-        del self.decompressed_model_manual
+    def tearDownClass(cls):
+        shutil.rmtree(cls.test_dir)
+        del cls.decompressed_model
+        del cls.compressed_model
+        torch.cuda.empty_cache()
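
The second test class above depends on the default load path keeping weights compressed and swapping quantized Linear layers for CompressedLinear. Below is a standalone sketch of that check, reusing the helpers the test imports; the stub is the W4A16 one from this commit's configs, and the expectation that some modules such as lm_head stay plain Linear follows the test's own comment.

from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization.utils import iter_named_leaf_modules
from transformers import AutoModelForCausalLM

stub = "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-compressed"

# Default load (no run_compressed=False): weights stay compressed and quantized
# Linear layers are replaced with CompressedLinear, which has its own forward call
compressed_model = AutoModelForCausalLM.from_pretrained(
    stub, torch_dtype="auto", device_map="auto"
)

# Count leaf modules that will run the CompressedLinear forward
compressed_linear_count = sum(
    isinstance(module, CompressedLinear)
    for _, module in iter_named_leaf_modules(compressed_model)
)

# Some layers (ex. lm_head) are typically left uncompressed, so the count is
# positive but smaller than the total number of Linear modules
assert compressed_linear_count > 0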
