
Commit e521436

[CPU] Linearize gpt_oss model and add separate example to quantize it to w4a8
Signed-off-by: Sharif Inamdar <[email protected]>
1 parent db0b68d commit e521436

File tree

2 files changed: +322, -0 lines
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)

from llmcompressor.modeling.gpt_oss import convert_model_for_quantization_gptoss


def main():
    MODEL_ID = "openai/gpt-oss-20b"
    BASE_NAME = MODEL_ID.rstrip("/").split("/")[-1]
    OUTPUT_DIR = f"{BASE_NAME}-w4a8-channelwise"

    print(f"[GPT-OSS] Loading model: {MODEL_ID}")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

    # ---- GPT-OSS MoE → linear experts conversion ----
    print("[GPT-OSS] Converting fused MoE experts to LinearExperts for quantization...")
    convert_model_for_quantization_gptoss(model)
    print("[GPT-OSS] Conversion completed.")

    # ---- Quantization config: W4A8 (int4 weights, int8 activations) ----

    # Weights: 4-bit, channelwise, symmetric, static
    weights_args = QuantizationArgs(
        num_bits=4,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    )

    # Activations: 8-bit, per-token, asymmetric, dynamic
    activations_args = QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.TOKEN,
        symmetric=False,
        dynamic=True,
        observer=None,
    )

    # Apply to all Linear layers, excluding lm_head
    scheme = QuantizationScheme(
        targets=["Linear"],
        weights=weights_args,
        input_activations=activations_args,
    )

    recipe = QuantizationModifier(
        config_groups={"group_0": scheme},
        ignore=["lm_head"],
    )

    print(f"[GPT-OSS] Starting oneshot quantization → {OUTPUT_DIR}")
    oneshot(
        model=model,
        recipe=recipe,
        tokenizer=tokenizer,
        output_dir=OUTPUT_DIR,
        trust_remote_code_model=True,
    )
    print(f"[GPT-OSS] Quantization finished. Quantized model written to: {OUTPUT_DIR}")


if __name__ == "__main__":
    main()
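
As a follow-up check (not part of this commit), the checkpoint written to OUTPUT_DIR should load like any other compressed-tensors model. A minimal smoke-test sketch, assuming a transformers version with compressed-tensors checkpoint support and a runtime that can execute the W4A8 format; the prompt, generation settings, and local path are arbitrary assumptions:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed output path; matches the naming scheme used by the example script above.
OUTPUT_DIR = "gpt-oss-20b-w4a8-channelwise"

model = AutoModelForCausalLM.from_pretrained(
    OUTPUT_DIR,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR, trust_remote_code=True)

# Short generation as a sanity check that the quantized weights load and run.
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))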
Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn


class LinearExpert(nn.Module):
    """
    One MoE expert with separate gate / up / down projections.

    This mirrors the GPT-OSS expert behavior:
        gate = clamp(gate_proj(x))
        up   = clamp(up_proj(x))
        glu  = gate * sigmoid(alpha * gate)
        y    = down_proj((up + 1) * glu)
    """

    def __init__(self, hidden_size: int, intermediate_size: int, alpha: float, limit: float):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.gate_proj(x)
        up = self.up_proj(x)

        gate = gate.clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)

        glu = gate * torch.sigmoid(self.alpha * gate)
        act = (up + 1) * glu
        return self.down_proj(act)


class LinearExperts(nn.Module):
    """
    Container of multiple LinearExpert modules, driven by router_indices / routing_weights.

    This is the "separate gate/up" layout.
    It is meant to replace the original GPT-OSS `experts` submodule.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_experts: int,
        alpha: float = 1.702,
        limit: float = 7.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.expert_dim = intermediate_size
        self.num_experts = num_experts
        self.alpha = alpha
        self.limit = limit

        self.experts = nn.ModuleList(
            [LinearExpert(hidden_size, intermediate_size, alpha, limit) for _ in range(num_experts)]
        )

    @torch.no_grad()
    def copy_from_fused_weights(
        self,
        legacy_gate_up_W: torch.Tensor,  # [E, H, 2D]
        legacy_gate_up_b: torch.Tensor,  # [E, 2D]
        legacy_down_W: torch.Tensor,  # [E, D, H]
        legacy_down_b: torch.Tensor,  # [E, H]
    ) -> None:
        """
        De-interleave fused gate_up weights/bias and copy into separate gate/up experts.
        """
        E, H, twoD = legacy_gate_up_W.shape
        assert E == self.num_experts
        D = twoD // 2
        assert D == self.expert_dim

        for i in range(E):
            Wi = legacy_gate_up_W[i]  # [H, 2D]
            bi = legacy_gate_up_b[i]  # [2D]

            Wg = Wi[:, 0::2].contiguous()  # [H, D]
            Wu = Wi[:, 1::2].contiguous()  # [H, D]
            bg = bi[0::2].contiguous()  # [D]
            bu = bi[1::2].contiguous()  # [D]

            expert = self.experts[i]
            expert.gate_proj.weight.copy_(Wg.t())
            expert.gate_proj.bias.copy_(bg)
            expert.up_proj.weight.copy_(Wu.t())
            expert.up_proj.bias.copy_(bu)

            expert.down_proj.weight.copy_(legacy_down_W[i].t())
            expert.down_proj.bias.copy_(legacy_down_b[i])

    def forward(
        self,
        hidden_states: torch.Tensor,  # [B, T, H]
        router_indices: Optional[torch.Tensor] = None,  # [B, T, top_k] or [tokens, top_k]
        routing_weights: Optional[torch.Tensor] = None,  # [B, T, E] or [tokens, E]
    ) -> torch.Tensor:
        """
        Implements the MoE computation using the router outputs.

        This is compatible with the GPT-OSS MoE call pattern:
            experts(hidden_states, router_indices, routing_weights)
        """
        assert routing_weights is not None and router_indices is not None, "router inputs required"

        # Normalize shapes to [tokens, H], [tokens, top_k], [tokens, E]
        if hidden_states.dim() == 3:
            B, T, H = hidden_states.shape
            x = hidden_states.reshape(-1, H)
        else:
            # Already flattened
            B, T = 1, hidden_states.shape[0]
            H = hidden_states.shape[-1]
            x = hidden_states

        if router_indices.dim() == 3:
            router_indices = router_indices.reshape(-1, router_indices.shape[-1])
        if routing_weights.dim() == 3:
            routing_weights = routing_weights.reshape(-1, routing_weights.shape[-1])

        num_experts_plus_dummy = routing_weights.shape[1]
        out = torch.zeros_like(x)

        # GPT-OSS router uses an extra "no expert" bucket at index E
        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(
                router_indices, num_classes=num_experts_plus_dummy
            ).permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for idx in expert_hit:
            e = idx[0].item()
            if e == self.num_experts:
                # Skip "no expert" bucket
                continue

            _, token_idx = torch.where(expert_mask[e])
            xi = x[token_idx]

            expert = self.experts[e]
            yi = expert(xi)

            w = routing_weights[token_idx, e, None]
            out.index_add_(0, token_idx, (yi * w).to(out.dtype))

        return out.view(B, -1, H)


@dataclass
class ExpertMeta:
    path: str
    hidden_size: int
    intermediate_size: int
    num_experts: int
    device: torch.device
    dtype: torch.dtype


def get_module_by_path(root: nn.Module, dotpath: str) -> nn.Module:
    m: nn.Module = root
    for p in dotpath.split("."):
        m = getattr(m, p)
    return m


def set_module_by_path(root: nn.Module, dotpath: str, new_module: nn.Module) -> None:
    parts = dotpath.split(".")
    parent = get_module_by_path(root, ".".join(parts[:-1]))
    setattr(parent, parts[-1], new_module)


def find_experts(model: nn.Module) -> List[ExpertMeta]:
    """
    Locate GPT-OSS MoE expert modules under model.model.layers[*].mlp.experts.
    """
    metas: List[ExpertMeta] = []
    for li, layer in enumerate(model.model.layers):
        experts = layer.mlp.experts
        device = next(experts.parameters(), torch.zeros(())).device
        dtype = next(experts.parameters(), torch.zeros(())).dtype
        intermediate = getattr(experts, "expert_dim", None)
        if intermediate is None:
            intermediate = getattr(experts, "intermediate_size")

        metas.append(
            ExpertMeta(
                path=f"model.layers.{li}.mlp.experts",
                hidden_size=experts.hidden_size,
                intermediate_size=intermediate,
                num_experts=experts.num_experts,
                device=device,
                dtype=dtype,
            )
        )
    return metas


def convert_model_for_quantization_gptoss(model: nn.Module) -> None:
    """
    In-place conversion of a GPT-OSS model:

    - Finds all fused MoE expert blocks (with gate_up_proj/down_proj).
    - Replaces them with LinearExperts that expose plain nn.Linear
      parameters (gate_proj, up_proj, down_proj), which play nicely
      with LLM Compressor W4A8 quantization.
    """
    metas = find_experts(model)
    for meta in metas:
        legacy = get_module_by_path(model, meta.path)

        # Sanity check that this is the fused layout we expect.
        if not all(hasattr(legacy, attr) for attr in ["gate_up_proj",
                                                      "gate_up_proj_bias",
                                                      "down_proj",
                                                      "down_proj_bias"]):
            continue

        new_exp = LinearExperts(
            hidden_size=meta.hidden_size,
            intermediate_size=meta.intermediate_size,
            num_experts=meta.num_experts,
        ).to(device=meta.device, dtype=meta.dtype)

        new_exp.copy_from_fused_weights(
            legacy_gate_up_W=legacy.gate_up_proj,
            legacy_gate_up_b=legacy.gate_up_proj_bias,
            legacy_down_W=legacy.down_proj,
            legacy_down_b=legacy.down_proj_bias,
        )

        set_module_by_path(model, meta.path, new_exp)
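
The de-interleaving convention in copy_from_fused_weights (even gate_up columns feed gate_proj, odd columns feed up_proj) can be checked numerically against the fused layout. A minimal sketch with toy shapes; the import path mirrors the example script's llmcompressor.modeling.gpt_oss and the tensor layouts follow the [E, H, 2D] / [E, 2D] / [E, D, H] / [E, H] comments above, but the exact module location is an assumption:

import torch

from llmcompressor.modeling.gpt_oss import LinearExperts  # assumed module path

E, H, D = 2, 8, 16  # toy sizes: num_experts, hidden_size, intermediate_size

# Random "fused" parameters in the layout copy_from_fused_weights expects.
gate_up_W = torch.randn(E, H, 2 * D)  # [E, H, 2D], gate/up columns interleaved
gate_up_b = torch.randn(E, 2 * D)     # [E, 2D]
down_W = torch.randn(E, D, H)         # [E, D, H]
down_b = torch.randn(E, H)            # [E, H]

experts = LinearExperts(hidden_size=H, intermediate_size=D, num_experts=E)
experts.copy_from_fused_weights(gate_up_W, gate_up_b, down_W, down_b)

# Reference: apply the GPT-OSS expert math (alpha=1.702, limit=7.0, the
# LinearExperts defaults) to the de-interleaved fused weights of expert 0.
x = torch.randn(4, H)
gate = (x @ gate_up_W[0, :, 0::2] + gate_up_b[0, 0::2]).clamp(max=7.0)
up = (x @ gate_up_W[0, :, 1::2] + gate_up_b[0, 1::2]).clamp(min=-7.0, max=7.0)
glu = gate * torch.sigmoid(1.702 * gate)
ref = ((up + 1) * glu) @ down_W[0] + down_b[0]

torch.testing.assert_close(experts.experts[0](x), ref, rtol=1e-4, atol=1e-4)
print("fused and linearized expert outputs match")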
