
Commit 201d06c

Remove gpt_oss test code and add in examples

- Add gpt_oss_20b_example.py, which performs the model conversion and quantization
- Clean up gpt_oss.py by removing the test code

Signed-off-by: Sharif Inamdar <[email protected]>

1 parent a30a86a commit 201d06c

File tree

2 files changed: +86, -51 lines
gpt_oss_20b_example.py (new file)

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)

from llmcompressor.modeling.gpt_oss import convert_model_for_quantization_gptoss


def main():
    MODEL_ID = "openai/gpt-oss-20b"
    BASE_NAME = MODEL_ID.rstrip("/").split("/")[-1]
    OUTPUT_DIR = f"{BASE_NAME}-w4a8-channelwise"

    print(f"[GPT-OSS] Loading model: {MODEL_ID}")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # ---- GPT-OSS MoE → linear experts conversion ----
    print("[GPT-OSS] Converting fused MoE experts to SequentialGPTOSSMoE for quantization...")
    convert_model_for_quantization_gptoss(model)
    print("[GPT-OSS] Conversion completed.")

    # ---- Quantization config: W4A8 (int4 weights, int8 activations) ----

    # Weights: 4-bit, channelwise, symmetric, static
    weights_args = QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    )

    # Activations: 8-bit, per-token, asymmetric, dynamic
    activations_args = QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=False,
        dynamic=True,
        observer=None,
    )

    # Apply to all Linear layers, excluding lm_head
    scheme = QuantizationScheme(
        targets=["Linear"],
        weights=weights_args,
        input_activations=activations_args,
    )

    recipe = QuantizationModifier(
        config_groups={"group_0": scheme},
        ignore=["lm_head"],
    )

    print(f"[GPT-OSS] Starting oneshot quantization → {OUTPUT_DIR}")
    oneshot(
        model=model,
        recipe=recipe,
        tokenizer=tokenizer,
        output_dir=OUTPUT_DIR,
        trust_remote_code_model=True,
    )
    print(f"[GPT-OSS] Quantization finished. Quantized model written to: {OUTPUT_DIR}")


if __name__ == "__main__":
    main()
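
Once the example finishes, oneshot will have written the compressed checkpoint to OUTPUT_DIR. The snippet below is a minimal sanity check, not part of the commit: it assumes the installed transformers build can deserialize compressed-tensors checkpoints (such checkpoints are more commonly served with vLLM), and the prompt and generation length are arbitrary.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

OUTPUT_DIR = "gpt-oss-20b-w4a8-channelwise"  # matches the example's output directory

# Reload the quantized checkpoint and run a short generation as a sanity check.
model = AutoModelForCausalLM.from_pretrained(
    OUTPUT_DIR,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR, trust_remote_code=True)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))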

src/llmcompressor/modeling/gpt_oss.py

Lines changed: 6 additions & 51 deletions
@@ -1,12 +1,8 @@
 import gc
 import torch
 from torch import nn
-import os
-from transformers import AutoModelForCausalLM, AutoTokenizer
 
-from llmcompressor import oneshot
 from llmcompressor.utils.dev import skip_weights_initialize
-from llmcompressor.modifiers.quantization import QuantizationModifier
 from compressed_tensors.utils import align_module_device, update_offload_parameter
 
 def convert_model_for_quantization_gptoss(model):
@@ -49,8 +45,9 @@ def convert_model_for_quantization_gptoss(model):
     if to_delete:
         gc.collect()
         try:
-            torch.cuda.synchronize()
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+                torch.cuda.empty_cache()
         except Exception as e:
             print(f"[GPT-OSS] Warning: Failed to empty CUDA cache: {e}", flush=True)
 
@@ -163,54 +160,12 @@ def forward(self, hidden_states):
         for j in range(self.top_k):
             idx = router_indices[:, j]
             w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1)
-            unique_experts = torch.unique(idx)
-            for e in unique_experts:
+            for e in range(self.num_experts):
                 mask = (idx == e)
-                out[mask] += self.experts[e](x[mask]) * w[mask]
+                if mask.any():
+                    out[mask] += self.experts[e](x[mask]) * w[mask]
 
         out = out.view(B, T, H)
         router_scores = router_scores.view(B * T, -1)  # shape doesn't matter much; it's ignored by the decoder
         return out, router_scores
 
-
-model_id = "unsloth/gpt-oss-120b-BF16"
-
-model = AutoModelForCausalLM.from_pretrained(
-    model_id,
-    torch_dtype=torch.bfloat16,
-    device_map="auto",
-    trust_remote_code=True,
-)
-tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
-
-convert_model_for_quantization_gptoss(model)
-
-# -----------------------------
-# Quantization recipe
-# -----------------------------
-recipe = QuantizationModifier(
-    targets="Linear",
-    scheme="FP8_DYNAMIC",
-    ignore=[
-        "re:.*lm_head",
-        "re:.*self_attn",
-        "re:.*attn",
-        "re:.*attention.*",
-        "re:.*router",
-    ],
-)
-
-SAVE_DIR = f"{model_id.split('/')[-1]}-FP8-Dynamic"
-
-# Oneshot quantization
-oneshot(
-    model=model,
-    tokenizer=tokenizer,
-    recipe=recipe,
-    trust_remote_code_model=True,
-    output_dir=SAVE_DIR,
-)
-
-# Save compressed
-model.save_pretrained(SAVE_DIR, save_compressed=True)
-tokenizer.save_pretrained(SAVE_DIR)
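
The updated forward loop visits every expert index for each routing slot and applies an expert only to the tokens whose top-k index selects it, skipping experts that received no tokens. The standalone toy below sketches that masked-dispatch pattern for a single routing slot; the sizes and the plain nn.Linear experts are illustrative, not the repository's SequentialGPTOSSMoE module.

import torch
from torch import nn

num_experts, hidden = 4, 8
experts = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_experts)])

x = torch.randn(6, hidden)                 # 6 flattened tokens
idx = torch.randint(0, num_experts, (6,))  # expert chosen for each token in this slot
w = torch.rand(6, 1)                       # routing weight per token

out = torch.zeros_like(x)
for e in range(num_experts):
    mask = (idx == e)
    if mask.any():                         # skip experts with no routed tokens
        out[mask] += experts[e](x[mask]) * w[mask]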
