-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathtest_quantize.py
More file actions
66 lines (51 loc) · 2.25 KB
/
test_quantize.py
File metadata and controls
66 lines (51 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/usr/bin/env python3
#
# Run test_model.py first
import os
import torch
from transformers import BitsAndBytesConfig, Qwen3MoeModel, set_seed
from qwen3_moe_fused.modular_qwen3_moe_fused import Qwen3MoeFusedModel
from qwen3_moe_fused.quantize.quantizer import patch_bnb_quantizer
from test_utils import get_rtol_atol
# Set the AUTOTUNE_DISABLE environment variable to "1" for this process.
# NOTE(review): this assignment runs after the qwen3_moe_fused imports above; if the
# flag is read at import time it would have no effect here — confirm where it is consumed.
os.environ["AUTOTUNE_DISABLE"] = "1"
def main():
    """Compare 4-bit quantized unfused and fused Qwen3 MoE models against the unquantized model.

    Steps, in order:

    1. Quantize the unfused and the fused checkpoints under ``./pretrained`` with a
       bitsandbytes NF4 config (double quantization, bfloat16 compute) and save each
       result to its own ``*-quantized`` directory.
    2. Load the unquantized unfused model, then reload both saved quantized checkpoints.
    3. Feed the same random token ids through all three models and print the result of
       ``get_rtol_atol`` for each pair of last hidden states.

    Everything is placed on the ``"cuda"`` device. The source checkpoints must already
    exist on disk (the file header says to run ``test_model.py`` first).
    """
    # Patch the bitsandbytes quantizer before any model is loaded.
    patch_bnb_quantizer()

    target_device = "cuda"
    load_dtype = torch.bfloat16

    plain_dir = "./pretrained/qwen-moe-tiny"
    fused_dir = "./pretrained/qwen-moe-tiny-fused"
    plain_quantized_dir = "./pretrained/qwen-moe-tiny-quantized"
    fused_quantized_dir = "./pretrained/qwen-moe-tiny-fused-quantized"

    set_seed(42)

    token_id_upper_bound = 151936
    input_shape = (7, 13)  # (batch_size, seq_len)

    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # Keyword arguments shared by both quantize-on-load calls.
    quantize_on_load = {
        "device_map": target_device,
        "torch_dtype": load_dtype,
        "quantization_config": nf4_config,
    }

    # Pass 1: quantize each checkpoint while loading it, then write it back to disk.
    quantized = Qwen3MoeModel.from_pretrained(plain_dir, **quantize_on_load)
    quantized.save_pretrained(plain_quantized_dir)

    fused_quantized = Qwen3MoeFusedModel.from_pretrained(fused_dir, **quantize_on_load)
    fused_quantized.save_pretrained(fused_quantized_dir)

    # Pass 2: load the reference model and reload the saved quantized checkpoints.
    # None of these calls pass torch_dtype or quantization_config.
    # NOTE(review): presumably the quantization settings are picked up from the saved
    # checkpoints — confirm.
    reference = Qwen3MoeModel.from_pretrained(plain_dir, device_map=target_device)
    quantized = Qwen3MoeModel.from_pretrained(plain_quantized_dir, device_map=target_device)
    fused_quantized = Qwen3MoeFusedModel.from_pretrained(fused_quantized_dir, device_map=target_device)

    token_ids = torch.randint(
        0,
        token_id_upper_bound,
        input_shape,
        device=target_device,
        dtype=torch.int32,
    )

    out_reference = reference(input_ids=token_ids).last_hidden_state
    out_quantized = quantized(input_ids=token_ids).last_hidden_state
    out_fused_quantized = fused_quantized(input_ids=token_ids).last_hidden_state

    # Each pair is (tensor under test, tensor it is compared against).
    comparisons = (
        (out_quantized, out_reference),
        (out_fused_quantized, out_reference),
        (out_fused_quantized, out_quantized),
    )
    for candidate, baseline in comparisons:
        print(get_rtol_atol(candidate, baseline))
# Run the comparison only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()