Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions examples/awq/fp8_block_llama_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization.quantization import QuantizationModifier

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
Expand Down Expand Up @@ -49,9 +50,19 @@ def tokenize(sample):


# Configure the quantization algorithm to run.
# AWQModifier performs smoothing and must be followed by a QuantizationModifier
# which applies the actual quantization.
recipe = [
AWQModifier(
ignore=["lm_head"], scheme="FP8_BLOCK", targets=["Linear"], duo_scaling="both"
ignore=["lm_head"],
scheme="FP8_BLOCK",
targets=["Linear"],
duo_scaling="both",
),
QuantizationModifier(
targets="Linear",
scheme="FP8_BLOCK",
ignore=["lm_head"],
),
]

Expand All @@ -76,6 +87,6 @@ def tokenize(sample):
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-awq-asym"
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-awq-fp8-block"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
15 changes: 13 additions & 2 deletions examples/awq/fp8_dynamic_llama_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization.quantization import QuantizationModifier

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
Expand Down Expand Up @@ -49,9 +50,19 @@ def tokenize(sample):


# Configure the quantization algorithm to run.
# AWQModifier performs smoothing and must be followed by a QuantizationModifier
# which applies the actual quantization.
recipe = [
AWQModifier(
ignore=["lm_head"], scheme="FP8_DYNAMIC", targets=["Linear"], duo_scaling="both"
ignore=["lm_head"],
scheme="FP8_DYNAMIC",
targets=["Linear"],
duo_scaling="both",
),
QuantizationModifier(
targets="Linear",
scheme="FP8_DYNAMIC",
ignore=["lm_head"],
),
]

Expand All @@ -76,6 +87,6 @@ def tokenize(sample):
print("==========================================\n\n")

# Save to disk compressed.
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-awq-asym"
SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-awq-fp8-dynamic"
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)
13 changes: 12 additions & 1 deletion examples/awq/llama_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization.quantization import QuantizationModifier

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
Expand Down Expand Up @@ -49,9 +50,19 @@ def tokenize(sample):


# Configure the quantization algorithm to run.
# AWQModifier performs smoothing and must be followed by a QuantizationModifier
# which applies the actual quantization.
recipe = [
AWQModifier(
ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"], duo_scaling="both"
ignore=["lm_head"],
scheme="W4A16_ASYM",
targets=["Linear"],
duo_scaling="both",
),
QuantizationModifier(
targets="Linear",
scheme="W4A16_ASYM",
ignore=["lm_head"],
),
]

Expand Down
13 changes: 12 additions & 1 deletion examples/awq/llama_example_with_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization.quantization import QuantizationModifier

# Select model and load it.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
Expand Down Expand Up @@ -105,9 +106,19 @@ def tokenize(sample):
ds = ds.map(tokenize, remove_columns=ds.column_names)

# Configure the quantization algorithm to run.
# AWQModifier performs smoothing and must be followed by a QuantizationModifier
# which applies the actual quantization.
recipe = [
AWQModifier(
ignore=["lm_head"], scheme="W4A16_ASYM", targets=["Linear"], duo_scaling="both"
ignore=["lm_head"],
scheme="W4A16_ASYM",
targets=["Linear"],
duo_scaling="both",
),
QuantizationModifier(
targets="Linear",
scheme="W4A16_ASYM",
ignore=["lm_head"],
),
]

Expand Down
8 changes: 8 additions & 0 deletions examples/awq/qwen3_coder_moe_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,26 @@

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization.quantization import QuantizationModifier

MODEL_ID = "Qwen/Qwen3-Coder-30B-A3B-Instruct"
SAVE_DIR = MODEL_ID.split("/")[-1] + "-W4A16-awq"

# Configure the quantization algorithm to run.
# AWQModifier performs smoothing and must be followed by a QuantizationModifier
# which applies the actual quantization.
recipe = [
AWQModifier(
duo_scaling=False,
ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
scheme="W4A16",
targets=["Linear"],
),
QuantizationModifier(
targets="Linear",
scheme="W4A16",
ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
),
]

# Select calibration dataset.
Expand Down
8 changes: 8 additions & 0 deletions examples/awq/qwen3_moe_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.quantization.quantization import QuantizationModifier

# Select model and load it.
MODEL_ID = "Qwen/Qwen3-30B-A3B"
Expand Down Expand Up @@ -49,13 +50,20 @@ def tokenize(sample):


# Configure the quantization algorithm to run.
# AWQModifier performs smoothing and must be followed by a QuantizationModifier
# which applies the actual quantization.
# NOTE: vllm currently does not support asym MoE, using symmetric here
recipe = [
AWQModifier(
ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
scheme="W4A16",
targets=["Linear"],
),
QuantizationModifier(
targets="Linear",
scheme="W4A16",
ignore=["lm_head", "re:.*mlp.gate$", "re:.*mlp.shared_expert_gate$"],
),
]

# Apply algorithms.
Expand Down
Loading