Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions examples/awq/regression_tests/llama3_awq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import argparse
import time

import torch
from compressed_tensors.offload import dispatch_model
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

# Model under test for this AWQ regression run.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"

# Calibration dataset configuration (chat-formatted UltraChat split).
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 256  # default; overridable via --num-samples
MAX_SEQUENCE_LENGTH = 512  # calibration samples are truncated to this length


def main():
    """Run the Llama-3-8B AWQ regression: quantize, sample-generate, save.

    CLI flags:
        --scheme:      AWQ quantization scheme (default: W4A16_ASYM)
        --save-dir:    output directory (default: "<model-name>-<scheme>")
        --num-samples: calibration sample count (default: NUM_CALIBRATION_SAMPLES)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--scheme", default="W4A16_ASYM")
    parser.add_argument("--save-dir", default=None)
    parser.add_argument("--num-samples", type=int, default=NUM_CALIBRATION_SAMPLES)
    args = parser.parse_args()

    num_samples = args.num_samples

    # NOTE(review): `dtype="auto"` is the parameter name in recent transformers
    # releases; older releases only recognize `torch_dtype` and may ignore an
    # unknown `dtype` kwarg (loading in float32) -- confirm against the pinned
    # transformers version.
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
Comment on lines +29 to +30
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check the actual code at the specified lines
echo "=== Checking llama3_awq.py lines 25-35 ==="
cat -n examples/awq/regression_tests/llama3_awq.py | sed -n '25,35p'

echo ""
echo "=== Checking qwen25_32b_awq.py ==="
if [ -f "examples/awq/regression_tests/qwen25_32b_awq.py" ]; then
  grep -n "AutoModelForCausalLM.from_pretrained" examples/awq/regression_tests/qwen25_32b_awq.py
else
  echo "File not found"
fi

echo ""
echo "=== Checking llama4_scout_awq.py ==="
if [ -f "examples/awq/regression_tests/llama4_scout_awq.py" ]; then
  grep -n "AutoModelForCausalLM.from_pretrained" examples/awq/regression_tests/llama4_scout_awq.py
else
  echo "File not found"
fi

Repository: vllm-project/llm-compressor

Length of output: 697


🌐 Web query:

AutoModelForCausalLM.from_pretrained dtype parameter torch_dtype hugging face transformers

💡 Result:

The torch_dtype parameter in AutoModelForCausalLM.from_pretrained specifies the PyTorch data type (dtype) for loading the model weights, overriding the default torch.float32. It accepts a torch.dtype (e.g., torch.float16, torch.bfloat16) or the string "auto". When "auto" is used: - Checks config.json for a torch_dtype or dtype entry. - If absent, infers from the dtype of the first floating-point weight in the checkpoint. Examples: import torch from transformers import AutoModelForCausalLM # Load with specific dtype model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16) # Auto-detect dtype model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype="auto") Note: torch_dtype primarily affects parameters; some buffers with explicit dtypes (e.g., torch.float32) remain unchanged. In recent versions (post-v4.XX), it is also referred to as dtype in examples, but torch_dtype remains valid. This parameter helps reduce memory usage (e.g., FP16/BF16) and is passed via kwargs to the underlying PreTrainedModel.from_pretrained method.

Citations:


Verify the dtype parameter name against the pinned Transformers version.

Line 29 uses dtype="auto". Older Transformers releases only recognize torch_dtype="auto"; on those versions an unknown dtype keyword is ignored and the model loads at default precision (float32), which would invalidate the regression test's memory and timing measurements. Recent Transformers releases, however, accept dtype directly (torch_dtype was renamed). Check which version this project pins: if it predates the rename, switch to torch_dtype="auto" (which remains accepted for backward compatibility); otherwise dtype="auto" is correct as written.

The same issue exists in examples/awq/regression_tests/qwen25_32b_awq.py on line 29.

Suggested changes
-    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto")
+    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")

Apply the same fix to qwen25_32b_awq.py.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/awq/regression_tests/llama3_awq.py` around lines 29 - 30, The
model-loading call using AutoModelForCausalLM.from_pretrained currently passes
dtype="auto" which is ignored; change the parameter name to torch_dtype="auto"
so the model is loaded with the correct precision (e.g., update the call in
llama3_awq.py where AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto")
is used and make the identical change in qwen25_32b_awq.py for its
AutoModelForCausalLM.from_pretrained call; keep the
AutoTokenizer.from_pretrained(..., trust_remote_code=True) calls unchanged.


    # Load only the first `num_samples` rows, then shuffle deterministically
    # so repeated regression runs see the same calibration set.
    ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{num_samples}]")
    ds = ds.shuffle(seed=42)

    def preprocess(example):
        # Render each chat transcript to a single string via the model's
        # chat template (tokenization happens in a separate pass below).
        return {
            "text": tokenizer.apply_chat_template(
                example["messages"],
                tokenize=False,
            )
        }

    ds = ds.map(preprocess)

    def tokenize(sample):
        # Tokenize without padding or special tokens; truncate to the
        # calibration sequence length.
        return tokenizer(
            sample["text"],
            padding=False,
            max_length=MAX_SEQUENCE_LENGTH,
            truncation=True,
            add_special_tokens=False,
        )

    ds = ds.map(tokenize, remove_columns=ds.column_names)

    # AWQ recipe: quantize all Linear layers except the output head.
    recipe = [
        AWQModifier(
            ignore=["lm_head"],
            scheme=args.scheme,
            targets=["Linear"],
            duo_scaling="both",
        ),
    ]

    # Reset peak-memory stats so the report below covers only quantization.
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=num_samples,
    )

    # Report wall time and peak GPU memory for the regression record.
    elapsed_time = time.time() - start_time
    peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
    print("Quantization Complete")
    print(f"Time: {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")
    print(f"Peak GPU Memory: {peak_memory_gb:.2f} GB")

    # Smoke-test the quantized model with a short generation.
    print("\n\n")
    print("========== SAMPLE GENERATION ==============")
    dispatch_model(model)
    input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
        model.device
    )
    output = model.generate(input_ids, max_new_tokens=100)
    print(tokenizer.decode(output[0]))
    print("==========================================\n\n")

    # Save compressed weights + tokenizer; default dir is "<model>-<scheme>".
    save_dir = args.save_dir or (
        MODEL_ID.rstrip("/").split("/")[-1] + f"-{args.scheme}"
    )
    model.save_pretrained(save_dir, save_compressed=True)
    tokenizer.save_pretrained(save_dir)
    print(f"Model saved to {save_dir}")


if __name__ == "__main__":
    main()
128 changes: 128 additions & 0 deletions examples/awq/regression_tests/llama4_scout_awq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import argparse
import time

import torch
from datasets import load_dataset
from transformers import Llama4ForConditionalGeneration, Llama4Processor

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.awq.mappings import AWQMapping

# Model under test for this AWQ regression run.
MODEL_ID = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

# Calibration dataset configuration.
DATASET_ID = "neuralmagic/calibration"
NUM_CALIBRATION_SAMPLES = 512  # default; overridable via --num-samples
MAX_SEQUENCE_LENGTH = 8192  # calibration samples are truncated to this length


def main():
    """Run the Llama-4-Scout AWQ regression: quantize and save.

    CLI flags:
        --scheme:      AWQ quantization scheme (default: W4A16_ASYM)
        --save-dir:    output directory (default: "<model-name>-<scheme>")
        --num-samples: calibration sample count (default: NUM_CALIBRATION_SAMPLES)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--scheme", default="W4A16_ASYM")
    parser.add_argument("--save-dir", default=None)
    parser.add_argument("--num-samples", type=int, default=NUM_CALIBRATION_SAMPLES)
    args = parser.parse_args()

    num_samples = args.num_samples

    # NOTE(review): `dtype="auto"` is the parameter name in recent transformers
    # releases; older releases only recognize `torch_dtype` -- confirm against
    # the pinned transformers version.
    model = Llama4ForConditionalGeneration.from_pretrained(MODEL_ID, dtype="auto")
    processor = Llama4Processor.from_pretrained(MODEL_ID)

    ds = load_dataset(DATASET_ID, name="LLM", split=f"train[:{num_samples}]")

    def preprocess_function(example):
        # Wrap each plain-text message in the content-list structure expected
        # by the Llama-4 chat template, then tokenize in one pass.
        messages = []
        for message in example["messages"]:
            messages.append(
                {
                    "role": message["role"],
                    "content": [{"type": "text", "text": message["content"]}],
                }
            )

        return processor.apply_chat_template(
            messages,
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=MAX_SEQUENCE_LENGTH,
            tokenize=True,
            add_special_tokens=False,
            return_dict=True,
            add_generation_prompt=False,
        )

    ds = ds.map(preprocess_function, batched=False, remove_columns=ds.column_names)

    def data_collator(batch):
        # Calibration runs one sample at a time. pixel_values (if present) are
        # cast to bfloat16 with the leading batch dim squeezed; everything else
        # becomes a plain tensor.
        assert len(batch) == 1
        return {
            key: (
                torch.tensor(value)
                if key != "pixel_values"
                else torch.tensor(value, dtype=torch.bfloat16).squeeze(0)
            )
            for key, value in batch[0].items()
        }

    # Llama-4-Scout has both vision_model and language_model sub-models,
    # so mappings must be scoped to language_model to avoid dual matches.
    # The main experts use a fused gate_up_proj (not Linear), so only
    # shared_expert Linear layers are AWQ targets.
    recipe = AWQModifier(
        targets="Linear",
        scheme=args.scheme,
        ignore=[
            "re:.*lm_head",
            "re:.*self_attn",
            "re:.*router",
            "re:.*vision_model.*",
            "re:.*multi_modal_projector.*",
            "Llama4TextAttention",
        ],
        mappings=[
            AWQMapping(
                "re:.*language_model.*post_attention_layernorm$",
                [
                    "re:.*shared_expert.gate_proj$",
                    "re:.*shared_expert.up_proj$",
                ],
            ),
            AWQMapping(
                "re:.*shared_expert.up_proj$",
                ["re:.*shared_expert.down_proj$"],
            ),
        ],
    )

    # Reset peak-memory stats so the report below covers only quantization.
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    oneshot(
        model=model,
        dataset=ds,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=num_samples,
        data_collator=data_collator,
        sequential_targets=["Llama4TextMLP"],
    )

    # Report wall time and peak GPU memory for the regression record.
    elapsed_time = time.time() - start_time
    peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
    print("Quantization Complete")
    print(f"Time: {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")
    print(f"Peak GPU Memory: {peak_memory_gb:.2f} GB")

    # Save compressed weights + processor; default dir is "<model>-<scheme>".
    save_dir = args.save_dir or (MODEL_ID.rstrip("/").split("/")[-1] + f"-{args.scheme}")
    model.save_pretrained(save_dir, save_compressed=True)
    processor.save_pretrained(save_dir)
    print(f"Model saved to {save_dir}")


if __name__ == "__main__":
    main()
Comment on lines +19 to +128
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Run ruff format on this file before merge.

CI is already red on examples/awq/regression_tests/llama4_scout_awq.py due to ruff format --check, so this file needs a formatting pass to get the quality job green.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/awq/regression_tests/llama4_scout_awq.py` around lines 19 - 128, Run
ruff format on the file to fix the style errors reported by CI: run `ruff
format` (or your editor's ruff formatting) and stage the resulting changes;
focus on formatting inside the main() block and helper functions (main,
preprocess_function, data_collator, and the AWQModifier/AWQMapping recipe) so
the file passes `ruff format --check` and the CI quality job turns green. Ensure
you do not alter logic—only whitespace, line breaks, and import/style
formatting—then commit the formatted file.

125 changes: 125 additions & 0 deletions examples/awq/regression_tests/mixtral_awq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import argparse
import time

import torch
from compressed_tensors.offload import dispatch_model
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier
from llmcompressor.modifiers.awq.mappings import AWQMapping

# Model under test for this AWQ regression run.
MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Calibration dataset configuration (chat-formatted UltraChat split).
DATASET_ID = "HuggingFaceH4/ultrachat_200k"
DATASET_SPLIT = "train_sft"
NUM_CALIBRATION_SAMPLES = 256  # default; overridable via --num-samples
MAX_SEQUENCE_LENGTH = 512  # calibration samples are truncated to this length


def main():
    """Run the Mixtral-8x7B AWQ regression: quantize, sample-generate, save.

    CLI flags:
        --scheme:      AWQ quantization scheme (default: W4A16_ASYM)
        --save-dir:    output directory (default: "<model-name>-<scheme>")
        --num-samples: calibration sample count (default: NUM_CALIBRATION_SAMPLES)
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--scheme", default="W4A16_ASYM")
    parser.add_argument("--save-dir", default=None)
    parser.add_argument("--num-samples", type=int, default=NUM_CALIBRATION_SAMPLES)
    args = parser.parse_args()

    num_samples = args.num_samples

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, dtype=torch.bfloat16, trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    # Deterministic calibration subset: slice first, then seeded shuffle.
    dataset = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{num_samples}]")
    dataset = dataset.shuffle(seed=42)

    def render_chat(example):
        # Flatten each chat transcript into one templated string.
        rendered = tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
        )
        return {"text": rendered}

    def encode(sample):
        # Tokenize without padding or special tokens, truncated to the
        # calibration sequence length.
        return tokenizer(
            sample["text"],
            padding=False,
            max_length=MAX_SEQUENCE_LENGTH,
            truncation=True,
            add_special_tokens=False,
        )

    dataset = dataset.map(render_chat)
    dataset = dataset.map(encode, remove_columns=dataset.column_names)

    # Mixtral uses w1/w2/w3 naming for expert layers instead of
    # gate_proj/up_proj/down_proj, so we need custom mappings
    mixtral_mappings = [
        AWQMapping(
            "re:.*input_layernorm$",
            ["re:.*q_proj$", "re:.*k_proj$", "re:.*v_proj$"],
        ),
        AWQMapping("re:.*v_proj$", ["re:.*o_proj$"]),
        AWQMapping(
            "re:.*post_attention_layernorm$",
            [
                "re:.*block_sparse_moe.experts.*.w1$",
                "re:.*block_sparse_moe.experts.*.w3$",
            ],
        ),
        AWQMapping("re:.*w3$", ["re:.*w2$"]),
    ]
    recipe = [
        AWQModifier(
            ignore=[
                "lm_head",
                "re:.*block_sparse_moe.gate",
            ],
            scheme=args.scheme,
            targets=["Linear"],
            duo_scaling="both",
            mappings=mixtral_mappings,
        ),
    ]

    # Reset peak-memory stats so the report covers only quantization.
    torch.cuda.reset_peak_memory_stats()
    started = time.time()

    oneshot(
        model=model,
        dataset=dataset,
        recipe=recipe,
        max_seq_length=MAX_SEQUENCE_LENGTH,
        num_calibration_samples=num_samples,
        trust_remote_code_model=True,
    )

    # Report wall time and peak GPU memory for the regression record.
    elapsed_time = time.time() - started
    peak_memory_gb = torch.cuda.max_memory_allocated() / (1024**3)
    print("Quantization Complete")
    print(f"Time: {elapsed_time / 60:.2f} minutes ({elapsed_time:.2f} seconds)")
    print(f"Peak GPU Memory: {peak_memory_gb:.2f} GB")

    # Smoke-test the quantized model with a short generation.
    print("\n\n")
    print("========== SAMPLE GENERATION ==============")
    dispatch_model(model)
    prompt = tokenizer("Hello my name is", return_tensors="pt")
    input_ids = prompt.input_ids.to(model.device)
    output = model.generate(input_ids, max_new_tokens=100)
    print(tokenizer.decode(output[0]))
    print("==========================================\n\n")

    # Save compressed weights + tokenizer; default dir is "<model>-<scheme>".
    if args.save_dir:
        save_dir = args.save_dir
    else:
        save_dir = MODEL_ID.rstrip("/").split("/")[-1] + f"-{args.scheme}"
    model.save_pretrained(save_dir, save_compressed=True)
    tokenizer.save_pretrained(save_dir)
    print(f"Model saved to {save_dir}")


if __name__ == "__main__":
    main()
Loading
Loading