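"""Utilities for converting GPT-OSS fused MoE expert blocks into per-expert
gate_proj / up_proj / down_proj nn.Linear modules, so that weight-quantization
tooling such as LLM Compressor can target them like ordinary linear layers."""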
from __future__ import annotations

from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn as nn


class LinearExpert(nn.Module):
    """
    One MoE expert with separate gate / up / down projections.

    This mirrors the GPT-OSS expert behavior:
        gate = clamp(gate_proj(x))
        up = clamp(up_proj(x))
        glu = gate * sigmoid(alpha * gate)
        y = down_proj((up + 1) * glu)
    """

    def __init__(self, hidden_size: int, intermediate_size: int, alpha: float, limit: float):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.gate_proj(x)
        up = self.up_proj(x)

        gate = gate.clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)

        glu = gate * torch.sigmoid(self.alpha * gate)
        act = (up + 1) * glu
        return self.down_proj(act)


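# ---------------------------------------------------------------------------
# Illustrative self-check (a minimal sketch, not used by the conversion code):
# it shows that de-interleaving a fused gate_up weight into separate gate / up
# nn.Linear layers reproduces the fused GPT-OSS expert math. The function name
# and the random shapes below are arbitrary choices made for this example.
# ---------------------------------------------------------------------------
def _demo_fused_vs_separate(hidden_size: int = 16, intermediate_size: int = 32) -> None:
    torch.manual_seed(0)
    alpha, limit = 1.702, 7.0

    # Fused layout: gate_up = x @ W + b with W of shape [H, 2D];
    # even columns are the gate projection, odd columns are the up projection.
    W = torch.randn(hidden_size, 2 * intermediate_size)
    b = torch.randn(2 * intermediate_size)
    Wd = torch.randn(intermediate_size, hidden_size)
    bd = torch.randn(hidden_size)
    x = torch.randn(4, hidden_size)

    gate_up = x @ W + b
    gate, up = gate_up[..., 0::2], gate_up[..., 1::2]
    gate = gate.clamp(max=limit)
    up = up.clamp(min=-limit, max=limit)
    reference = ((up + 1) * (gate * torch.sigmoid(alpha * gate))) @ Wd + bd

    # Separate layout: copy the de-interleaved columns into nn.Linear weights.
    # nn.Linear stores weight as [out_features, in_features], hence the .t().
    expert = LinearExpert(hidden_size, intermediate_size, alpha, limit)
    with torch.no_grad():
        expert.gate_proj.weight.copy_(W[:, 0::2].t())
        expert.gate_proj.bias.copy_(b[0::2])
        expert.up_proj.weight.copy_(W[:, 1::2].t())
        expert.up_proj.bias.copy_(b[1::2])
        expert.down_proj.weight.copy_(Wd.t())
        expert.down_proj.bias.copy_(bd)

    assert torch.allclose(expert(x), reference, atol=1e-4)

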
class LinearExperts(nn.Module):
    """
    Container of multiple LinearExpert modules, driven by router_indices / routing_weights.

    This is the "separate gate/up" layout.
    It is meant to replace the original GPT-OSS `experts` submodule.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_experts: int,
        alpha: float = 1.702,
        limit: float = 7.0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.expert_dim = intermediate_size
        self.num_experts = num_experts
        self.alpha = alpha
        self.limit = limit

        self.experts = nn.ModuleList(
            [LinearExpert(hidden_size, intermediate_size, alpha, limit) for _ in range(num_experts)]
        )

    @torch.no_grad()
    def copy_from_fused_weights(
        self,
        legacy_gate_up_W: torch.Tensor,  # [E, H, 2D]
        legacy_gate_up_b: torch.Tensor,  # [E, 2D]
        legacy_down_W: torch.Tensor,  # [E, D, H]
        legacy_down_b: torch.Tensor,  # [E, H]
    ) -> None:
        """
        De-interleave fused gate_up weights/bias and copy into separate gate/up experts.
        """
        E, H, twoD = legacy_gate_up_W.shape
        assert E == self.num_experts
        D = twoD // 2
        assert D == self.expert_dim

        for i in range(E):
            Wi = legacy_gate_up_W[i]  # [H, 2D]
            bi = legacy_gate_up_b[i]  # [2D]

            Wg = Wi[:, 0::2].contiguous()  # [H, D]
            Wu = Wi[:, 1::2].contiguous()  # [H, D]
            bg = bi[0::2].contiguous()  # [D]
            bu = bi[1::2].contiguous()  # [D]

            expert = self.experts[i]
            expert.gate_proj.weight.copy_(Wg.t())
            expert.gate_proj.bias.copy_(bg)
            expert.up_proj.weight.copy_(Wu.t())
            expert.up_proj.bias.copy_(bu)

            expert.down_proj.weight.copy_(legacy_down_W[i].t())
            expert.down_proj.bias.copy_(legacy_down_b[i])

    def forward(
        self,
        hidden_states: torch.Tensor,  # [B, T, H]
        router_indices: Optional[torch.Tensor] = None,  # [B, T, top_k] or [tokens, top_k]
        routing_weights: Optional[torch.Tensor] = None,  # [B, T, E] or [tokens, E]
    ) -> torch.Tensor:
        """
        Implements the MoE computation using the router outputs.

        This is compatible with the GPT-OSS MoE call pattern:
            experts(hidden_states, router_indices, routing_weights)
        """
        assert routing_weights is not None and router_indices is not None, "router inputs required"

        # Normalize shapes to [tokens, H], [tokens, top_k], [tokens, E]
        if hidden_states.dim() == 3:
            B, T, H = hidden_states.shape
            x = hidden_states.reshape(-1, H)
        else:
            # Already flattened
            B, T = 1, hidden_states.shape[0]
            H = hidden_states.shape[-1]
            x = hidden_states

        if router_indices.dim() == 3:
            router_indices = router_indices.reshape(-1, router_indices.shape[-1])
        if routing_weights.dim() == 3:
            routing_weights = routing_weights.reshape(-1, routing_weights.shape[-1])

        num_experts_plus_dummy = routing_weights.shape[1]
        out = torch.zeros_like(x)

        # GPT-OSS router uses an extra "no expert" bucket at index E
        with torch.no_grad():
            # expert_mask: [num_classes, top_k, tokens]; expert_hit lists the
            # expert indices that received at least one token.
            expert_mask = torch.nn.functional.one_hot(
                router_indices, num_classes=num_experts_plus_dummy
            ).permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for idx in expert_hit:
            e = idx[0].item()
            if e == self.num_experts:
                # Skip "no expert" bucket
                continue

            _, token_idx = torch.where(expert_mask[e])
            xi = x[token_idx]

            expert = self.experts[e]
            yi = expert(xi)

            w = routing_weights[token_idx, e, None]
            out.index_add_(0, token_idx, (yi * w).to(out.dtype))

        # Restore the caller's layout: [B, T, H] for 3-D input, [tokens, H] otherwise.
        if hidden_states.dim() == 3:
            return out.view(B, T, H)
        return out


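# ---------------------------------------------------------------------------
# Illustrative usage sketch (not used by the conversion code): builds random
# router outputs in the layout LinearExperts.forward expects and runs it. The
# shapes, the top-k construction, and the function name are assumptions made
# purely for this example.
# ---------------------------------------------------------------------------
def _demo_routed_forward() -> None:
    torch.manual_seed(0)
    B, T, H, D, E, top_k = 2, 4, 16, 32, 4, 2
    experts = LinearExperts(hidden_size=H, intermediate_size=D, num_experts=E)

    hidden_states = torch.randn(B, T, H)
    router_logits = torch.randn(B * T, E)

    # Top-k routing: keep the k best experts per token, softmax over them,
    # and scatter the weights back into a dense [tokens, E] matrix.
    top_vals, router_indices = torch.topk(router_logits, top_k, dim=-1)
    routing_weights = torch.zeros_like(router_logits).scatter_(
        -1, router_indices, torch.softmax(top_vals, dim=-1)
    )

    out = experts(hidden_states, router_indices=router_indices, routing_weights=routing_weights)
    assert out.shape == hidden_states.shape

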
@dataclass
class ExpertMeta:
    path: str
    hidden_size: int
    intermediate_size: int
    num_experts: int
    device: torch.device
    dtype: torch.dtype


def get_module_by_path(root: nn.Module, dotpath: str) -> nn.Module:
    m: nn.Module = root
    for p in dotpath.split("."):
        m = getattr(m, p)
    return m


def set_module_by_path(root: nn.Module, dotpath: str, new_module: nn.Module) -> None:
    parts = dotpath.split(".")
    parent = get_module_by_path(root, ".".join(parts[:-1]))
    setattr(parent, parts[-1], new_module)


def find_experts(model: nn.Module) -> List[ExpertMeta]:
    """
    Locate GPT-OSS MoE expert modules under model.model.layers[*].mlp.experts.
    """
    metas: List[ExpertMeta] = []
    for li, layer in enumerate(model.model.layers):
        experts = layer.mlp.experts
        device = next(experts.parameters(), torch.zeros(())).device
        dtype = next(experts.parameters(), torch.zeros(())).dtype

        # Prefer the fused module's `expert_dim`; fall back to `intermediate_size`.
        intermediate = getattr(experts, "expert_dim", None)
        if intermediate is None:
            intermediate = experts.intermediate_size

        metas.append(
            ExpertMeta(
                path=f"model.layers.{li}.mlp.experts",
                hidden_size=experts.hidden_size,
                intermediate_size=intermediate,
                num_experts=experts.num_experts,
                device=device,
                dtype=dtype,
            )
        )
    return metas


def convert_model_for_quantization_gptoss(model: nn.Module) -> None:
    """
    In-place conversion of a GPT-OSS model:

    - Finds all fused MoE expert blocks (with gate_up_proj/down_proj).
    - Replaces them with LinearExperts that expose plain nn.Linear
      parameters (gate_proj, up_proj, down_proj), which play nicely
      with LLM Compressor W4A8 quantization.
    """
    metas = find_experts(model)
    for meta in metas:
        legacy = get_module_by_path(model, meta.path)

        # Sanity check that this is the fused layout we expect.
        if not hasattr(legacy, "gate_up_proj") or not hasattr(legacy, "down_proj"):
            continue

        new_exp = LinearExperts(
            hidden_size=meta.hidden_size,
            intermediate_size=meta.intermediate_size,
            num_experts=meta.num_experts,
        ).to(device=meta.device, dtype=meta.dtype)

        new_exp.copy_from_fused_weights(
            legacy_gate_up_W=legacy.gate_up_proj,
            legacy_gate_up_b=legacy.gate_up_proj_bias,
            legacy_down_W=legacy.down_proj,
            legacy_down_b=legacy.down_proj_bias,
        )

        set_module_by_path(model, meta.path, new_exp)

    print("[GPT-OSS] Converted fused MoE experts to LinearExperts for quantization.")
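

# ---------------------------------------------------------------------------
# Example entry point (a minimal sketch, assuming a Hugging Face GPT-OSS
# checkpoint such as "openai/gpt-oss-20b"; substitute whichever checkpoint you
# are actually quantizing, and adjust dtype/dequantization options to your
# transformers version). It only loads the model and converts the experts;
# the downstream quantization recipe is intentionally left out.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    _demo_fused_vs_separate()
    _demo_routed_forward()

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "openai/gpt-oss-20b",  # assumed checkpoint id
        torch_dtype=torch.bfloat16,
    )
    convert_model_for_quantization_gptoss(model)

    # After conversion, each mlp.experts module exposes plain nn.Linear
    # submodules (gate_proj / up_proj / down_proj) that quantization recipes
    # can target by module name.
    print(model.model.layers[0].mlp.experts.experts[0])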