24 changes: 24 additions & 0 deletions examples/model_free_ptq/qwen3.5_int8.py
@@ -0,0 +1,24 @@
from llmcompressor import model_free_ptq

MODEL_ID = "Qwen/Qwen3.5-35B-A3B"
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-W8A8-INT8"

# Apply W8A8 quantization to the model. Once quantized,
# the model is saved to SAVE_DIR in the compressed-tensors format.
model_free_ptq(
model_stub=MODEL_ID,
save_directory=SAVE_DIR,
scheme="W8A8",
ignore=[
"lm_head",
"re:.*mlp.gate$",
"re:.*mlp.shared_expert_gate.*",
"re:.*norm.*",
"re:.*embed_tokens.*",
"re:.*visual.*",
"re:.*conv1d.*"
],
max_workers=15,
device="cuda:0",
)
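The `re:`-prefixed entries in the `ignore` list above are regular expressions matched against module names, while plain entries like `lm_head` name a module exactly. A minimal stand-in for that resolution logic (this is an illustrative sketch, not llm-compressor's actual resolver, which may match differently):

```python
import re

def is_ignored(module_name: str, ignore: list[str]) -> bool:
    """Hypothetical sketch: 're:'-prefixed entries are regex patterns,
    everything else is an exact module-name match."""
    for entry in ignore:
        if entry.startswith("re:"):
            if re.match(entry[3:], module_name):
                return True
        elif module_name == entry:
            return True
    return False

ignore = ["lm_head", "re:.*mlp.gate$", "re:.*norm.*"]
print(is_ignored("model.layers.0.mlp.gate", ignore))         # True
print(is_ignored("model.layers.0.input_layernorm", ignore))  # True
print(is_ignored("model.layers.0.mlp.down_proj", ignore))    # False
```

Note that `.*mlp.gate$` anchors at the end of the name, so `mlp.gate_up_proj` is not ignored while the MoE router `mlp.gate` is.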
61 changes: 61 additions & 0 deletions src/llmcompressor/entrypoints/model_free/process.py
@@ -77,6 +77,7 @@ def process_file(
"""
assert not is_microscale_scheme(scheme), "Use `_process_file_microscale_scheme`"
tensors = load_file(file_path)
tensors = split_fused_moe_experts(tensors)

if converter is not None:
converter.process(tensors)
@@ -194,3 +195,63 @@ def process_file_microscale_scheme(
total_size = sum(tensor.nbytes for tensor in tensors.values())
weight_map = {key: os.path.basename(save_path) for key in tensors.keys()}
return total_size, weight_map


def split_fused_moe_experts(tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""
Find fused MoE expert weights (gate_up_proj/down_proj stored as 3D
tensors) and split them into individual 2D per-expert tensors.

Args:
tensors: Dictionary of loaded tensors from safetensors file

Returns:
New dictionary with split expert weights
"""
_tensors = {}
Contributor comment (medium):

The variable `_tensors` is unconventional for a local variable. A more descriptive name like `processed_tensors` or `new_tensors` would improve readability. The `_` prefix is usually reserved for internal/private variables by convention. This change would need to be applied throughout the function.

Collaborator reply: rename as `fused_tensors`

for name, tensor in tensors.items():
# Check if this is a MoE expert weight (3D tensor for experts)
if tensor.ndim == 3 and ("experts.gate_up_proj" in name or "experts.down_proj" in name):
# Get number of experts
num_experts = tensor.shape[0]

if "gate_up_proj" in name:
# gate_up_proj is typically [num_experts, 2*intermediate, hidden]
if tensor.shape[1] % 2 != 0:
print(f"Warning: gate_up_proj {name} has odd second dimension: {tensor.shape}")
continue

Comment on lines +222 to +224 — Contributor (high):

There are two issues here:

1. When a gate_up_proj tensor has an odd second dimension, the code continues without adding the original tensor to the _tensors dictionary. This causes the tensor to be silently dropped from the model, which is a significant bug.
2. The use of print for logging is not ideal for a library. Please use a proper logger, like loguru, which is used elsewhere in the project. This would also require importing logger at the top of the file.

The print statements on lines 240 and 252 should also be converted to logger.info for consistency.

Suggested change:
-            print(f"Warning: gate_up_proj {name} has odd second dimension: {tensor.shape}")
-            continue
+            logger.warning(f"gate_up_proj {name} has odd second dimension: {tensor.shape}, skipping split")
+            _tensors[name] = tensor
+            continue

hidden_size = tensor.shape[1] // 2
Contributor comment (medium):

The variable hidden_size is confusing. Based on the comment on line 220 (gate_up_proj is typically [num_experts, 2*intermediate, hidden]), tensor.shape[1] is 2 * intermediate_size. Therefore, this variable actually holds the intermediate_size. Renaming it would improve clarity. You will also need to update its usages on lines 232 and 233.

Suggested change:
-        hidden_size = tensor.shape[1] // 2
+        intermediate_size = tensor.shape[1] // 2


# Split into individual experts
for expert_idx in range(num_experts):
expert_tensor = tensor[expert_idx]  # [2*intermediate, hidden]

# Split gate and up projections
gate_proj = expert_tensor[:hidden_size, :]
Collaborator comment:

Might be a little more readable to just do the full indexing at the same time instead of doing it piecemeal; these are all views, so performance will be equivalent.
up_proj = expert_tensor[hidden_size:, :]

# Create new key names
base_key = name.replace("mlp.experts.gate_up_proj", f"mlp.experts.{expert_idx}")
_tensors[base_key + ".gate_proj.weight"] = gate_proj
_tensors[base_key + ".up_proj.weight"] = up_proj
Comment on lines +236 to +238 — Contributor (high):

The logic for generating new tensor names for the split gate_proj and up_proj experts is incorrect. It constructs the new key by appending to the original tensor name, which already includes .weight, resulting in an invalid key like ...weight.gate_proj.weight. You should remove the .weight suffix from the original name before creating the new keys.

Suggested change:
-                    base_key = name.replace("mlp.experts.gate_up_proj", f"mlp.experts.{expert_idx}")
+                    base_key = name.rsplit(".weight", 1)[0].replace("mlp.experts.gate_up_proj", f"mlp.experts.{expert_idx}")
                     _tensors[base_key + ".gate_proj.weight"] = gate_proj
                     _tensors[base_key + ".up_proj.weight"] = up_proj
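The reviewer's rsplit fix is easy to check in isolation with plain strings (the layer path below is illustrative, not taken from the PR):

```python
# Illustrative fused-expert tensor name (hypothetical layer path).
name = "model.layers.0.mlp.experts.gate_up_proj.weight"
expert_idx = 7

# Buggy version: keeps the trailing ".weight" from the original name.
buggy = name.replace("mlp.experts.gate_up_proj", f"mlp.experts.{expert_idx}")
print(buggy + ".gate_proj.weight")
# model.layers.0.mlp.experts.7.weight.gate_proj.weight  <- invalid key

# Fixed version: strip ".weight" before substituting the expert index.
fixed = name.rsplit(".weight", 1)[0].replace(
    "mlp.experts.gate_up_proj", f"mlp.experts.{expert_idx}"
)
print(fixed + ".gate_proj.weight")
# model.layers.0.mlp.experts.7.gate_proj.weight
```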


print(f"Split {name} into {num_experts} experts")

elif "down_proj" in name:
# down_proj is typically [num_experts, hidden, intermediate]
# Split into individual experts
for expert_idx in range(num_experts):
down_proj = tensor[expert_idx] # [hidden, intermediate]

# Create new key name
new_key = name.replace("mlp.experts.down_proj", f"mlp.experts.{expert_idx}") + ".down_proj.weight"
_tensors[new_key] = down_proj
Comment on lines +249 to +250 — Contributor (high):

Similar to the gate_up_proj case, the key generation for down_proj is incorrect. It appends .down_proj.weight to a name that already ends with .weight. Please correct this by first removing the original .weight suffix.

Suggested change:
-                new_key = name.replace("mlp.experts.down_proj", f"mlp.experts.{expert_idx}") + ".down_proj.weight"
+                new_key = name.rsplit(".weight", 1)[0].replace("mlp.experts.down_proj", f"mlp.experts.{expert_idx}") + ".down_proj.weight"
                 _tensors[new_key] = down_proj


print(f"Split {name} into {num_experts} experts")
else:
# Non-MoE or non-3D tensors, keep as is
_tensors[name] = tensor

return _tensors
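Taken together, and with the reviewers' key fix applied, the splitting logic amounts to the following. This is a minimal sketch using numpy in place of torch (the shapes follow the comments in the PR; the layer path and sizes are made up):

```python
import numpy as np

num_experts, intermediate, hidden = 4, 6, 8

# Fused MoE weights as stored on disk, per the shape comments in the PR:
# gate_up_proj is [num_experts, 2*intermediate, hidden],
# down_proj is [num_experts, hidden, intermediate].
tensors = {
    "model.layers.0.mlp.experts.gate_up_proj.weight": np.zeros(
        (num_experts, 2 * intermediate, hidden)
    ),
    "model.layers.0.mlp.experts.down_proj.weight": np.zeros(
        (num_experts, hidden, intermediate)
    ),
}

split = {}
for name, tensor in tensors.items():
    stem = name.rsplit(".weight", 1)[0]  # strip ".weight" before renaming
    if "gate_up_proj" in name:
        half = tensor.shape[1] // 2  # the intermediate size
        for i in range(tensor.shape[0]):
            base = stem.replace("mlp.experts.gate_up_proj", f"mlp.experts.{i}")
            split[base + ".gate_proj.weight"] = tensor[i, :half, :]
            split[base + ".up_proj.weight"] = tensor[i, half:, :]
    else:  # down_proj: one 2D slice per expert, name unchanged otherwise
        for i in range(tensor.shape[0]):
            base = stem.replace("mlp.experts.down_proj", f"mlp.experts.{i}")
            split[base + ".down_proj.weight"] = tensor[i]

print(split["model.layers.0.mlp.experts.0.gate_proj.weight"].shape)  # (6, 8)
print(len(split))  # 12: (gate + up) * 4 experts + down * 4 experts
```

The slices are views of the fused arrays, matching the collaborator's note that the piecemeal indexing in the PR carries no extra cost.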