Skip to content

Commit bcd81ad

Browse files
committed
fix: oQ3.5 estimate bpw/size correction for expert down_proj boost
1 parent 708ad21 commit bcd81ad

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

omlx/oq.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,30 @@ def _build_quant_plan(
495495
current_bpw = next_bpw
496496
break
497497

498+
# oQ3.5: mandatory expert down_proj 4-bit (Super Weights protection)
499+
if oq_level == 3.5:
500+
for path, shape in named_shapes.items():
501+
if path in boost_map:
502+
continue
503+
if not _is_routed_expert(path):
504+
continue
505+
if not any(p in path for p in ("down_proj", "w2")):
506+
continue
507+
cand_bits = base_bits + 1 # 3→4
508+
if cand_bits not in (2, 3, 4, 5, 6, 8):
509+
continue
510+
cand_gs = _gs_for_mode(cand_bits, _OQ_DEFAULT_GROUP_SIZE)
511+
cand_mode = _mode_for_bits(cand_bits)
512+
base_cost = _tensor_quantized_bytes(
513+
shape, base_bits, base_group_size, base_mode
514+
)
515+
cand_cost = _tensor_quantized_bytes(shape, cand_bits, cand_gs, cand_mode)
516+
delta = 8 * (cand_cost - base_cost)
517+
if delta > 0:
518+
boost_map[path] = {"bits": cand_bits, "group_size": cand_gs, "mode": cand_mode}
519+
total_bits_f += delta
520+
current_bpw = total_bits_f / total_params
521+
498522
candidates = []
499523
for path, shape in named_shapes.items():
500524
if path in boost_map:
@@ -726,6 +750,14 @@ def estimate_bpw_and_size(model_path: str, oq_level: int, group_size: int = 64)
726750

727751
effective_bpw = total_weighted_bits / max(total_params, 1)
728752

753+
# oQ3.5 correction: expert down_proj 3→4 bit not visible in pre-sanitize scan
754+
# (fused tensors like gate_up_proj don't have .weight suffix).
755+
# After sanitize, down_proj is ~31% of routed expert params → ~10% of total.
756+
# +1 bit for 10% of params ≈ +0.1 bpw.
757+
if oq_level == 3.5:
758+
effective_bpw += 0.3
759+
total_output_bytes = int(effective_bpw * total_params / 8)
760+
729761
source_total = sum(
730762
sf.stat().st_size for sf in source.glob("*.safetensors")
731763
)

0 commit comments

Comments
 (0)