@@ -90,6 +90,8 @@ Notes:

Here we provide several recipes for Llama3 models. The relative accuracy loss of the quantized models should be less than 1%.

> Note: You can also enable static quantization of the KV cache by adding the `--static_kv_dtype fp8` argument to `quantize.py`, or the `--static_kv_dtype=fp8` argument to `run_quant.sh` and `run_benchmark.sh`.
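
For example, based on the usage line documented in `run_quant.sh`, a recipe run with the static FP8 KV cache enabled might look like the following sketch (model paths are illustrative):

```bash
# Documented run_quant.sh usage plus the new KV-cache flag; adjust paths to your environment.
CUDA_VISIBLE_DEVICES=0 bash run_quant.sh \
    --topology=Llama-3.1-8B \
    --dtype=mxfp8 \
    --input_model=/models/Meta-Llama-3.1-8B-Instruct \
    --output_model=Llama-3.1-8B-MXFP8 \
    --static_kv_dtype=fp8
```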

#### Llama 3.1 8B MXFP8

AutoRound tuning helps improve accuracy; `iters` and `nsamples` are set higher than the defaults.
@@ -121,7 +123,7 @@ CUDA_VISIBLE_DEVICES=0 python quantize.py \
--low_gpu_mem_usage \
--export_format auto_round \
--export_path llama3.1-8B-MXFP4-MXFP8 \
--tasks mmlu piqa hellaswag gsm8k \
--tasks mmlu_llama piqa hellaswag gsm8k_llama \
--eval_batch_size 32
```

@@ -219,8 +221,7 @@ CUDA_VISIBLE_DEVICES=0,1 bash run_benchmark.sh --model_path=Llama-3.1-70B-MXFP8

The script automatically:
- Detects available GPUs from `CUDA_VISIBLE_DEVICES` and sets `tensor_parallel_size` accordingly
- Handles different `add_bos_token` settings for different tasks (GSM8K requires `False`, others use `True`)
- Runs default tasks: `piqa,hellaswag,mmlu,gsm8k` with batch size 8
- Runs default tasks: `piqa,hellaswag,mmlu_llama,gsm8k_llama` with batch size 8
- Supports custom task selection and batch size adjustment (see the example invocation below)
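
For example, a run restricted to two of the tasks, with a smaller batch size and an FP8 KV cache, might look like the following sketch (the model path is taken from the example above, the batch size is illustrative, and this assumes the model was quantized with a static FP8 KV cache):

```bash
CUDA_VISIBLE_DEVICES=0,1 bash run_benchmark.sh \
    --model_path=Llama-3.1-70B-MXFP8 \
    --tasks=mmlu_llama,gsm8k_llama \
    --batch_size=16 \
    --static_kv_dtype=fp8
```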


File: quantize.py
@@ -162,6 +162,13 @@ def get_accuracy(model_name_or_path, tokenizer=None, eval_tasks="mmlu", limit=No
default=[],
help="[mix-precision] ensure that listed layers are using same data type for quantization"
)
parser.add_argument(
"--static_kv_dtype",
default=None,
type=str,
choices=["fp8", "float8_e4m3fn"],
help="Data type for static quantize key and value.",
)
parser.add_argument("--use_recipe", action="store_true", help="whether to use recipe to quantize model")
parser.add_argument("--recipe_file", type=str, default="recipes/Meta-Llama-3.1-8B-Instruct_6bits.json", help="path of recipe file")
parser.add_argument("--iters", default=200, type=int, help="iters for autoround.")
@@ -248,6 +255,7 @@ def load_recipe_results(file_path):
target_bits=args.target_bits,
options=args.options,
shared_layers=args.shared_layers,
static_kv_dtype=args.static_kv_dtype,
enable_torch_compile=args.enable_torch_compile,
low_gpu_mem_usage=args.low_gpu_mem_usage,
export_format=args.export_format,

File: run_benchmark.sh
@@ -6,6 +6,7 @@
TASKS="piqa,hellaswag,mmlu_llama,gsm8k_llama"
BATCH_SIZE=64
GPU_MEMORY_UTILIZATION=0.8
KV_CACHE_DTYPE="auto"

while [[ $# -gt 0 ]]; do
case $1 in
@@ -25,13 +26,24 @@ while [[ $# -gt 0 ]]; do
GPU_MEMORY_UTILIZATION="${1#*=}"
shift
;;
--static_kv_dtype=*)
KV_CACHE_DTYPE="${1#*=}"
shift
;;
*)
echo "Unknown parameter: $1"
exit 1
;;
esac
done

# for fp8 kv cache
if [[ "$KV_CACHE_DTYPE" == "fp8" ]]; then
export VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION=0
export VLLM_ATTENTION_BACKEND="FLASHINFER"
echo "Using FP8 for KV cache"
fi

# Validate required parameters
if [[ -z "$MODEL_PATH" ]]; then
echo "Usage: bash run_benchmark.sh --model_path=<path_to_quantized_model> [--tasks=<tasks>] [--batch_size=<size>]"
@@ -65,6 +77,7 @@ fi
# Set common environment variables
export VLLM_ENABLE_AR_EXT=1
export TORCH_COMPILE_DISABLE=1
export VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION=1

# Function to run evaluation for specific tasks
run_evaluation() {
@@ -75,11 +88,11 @@ run_evaluation() {
echo "Running evaluation for tasks: $tasks (add_bos_token=$add_bos_token)"

# Print the command being executed
local cmd="lm_eval --model vllm --model_args pretrained=\"$MODEL_PATH\",add_bos_token=$add_bos_token,tensor_parallel_size=$TENSOR_PARALLEL_SIZE,gpu_memory_utilization=$GPU_MEMORY_UTILIZATION,data_parallel_size=1,max_model_len=8192 --tasks $tasks --batch_size $BATCH_SIZE $extra_args"
local cmd="lm_eval --model vllm --model_args pretrained=\"$MODEL_PATH\",add_bos_token=$add_bos_token,tensor_parallel_size=$TENSOR_PARALLEL_SIZE,gpu_memory_utilization=$GPU_MEMORY_UTILIZATION,data_parallel_size=1,max_model_len=8192,kv_cache_dtype=${KV_CACHE_DTYPE} --tasks $tasks --batch_size $BATCH_SIZE $extra_args"
echo "Executing command: $cmd"

lm_eval --model vllm \
--model_args pretrained="$MODEL_PATH",add_bos_token=$add_bos_token,tensor_parallel_size=$TENSOR_PARALLEL_SIZE,gpu_memory_utilization=$GPU_MEMORY_UTILIZATION,data_parallel_size=1,max_model_len=8192 \
--model_args pretrained="$MODEL_PATH",add_bos_token=$add_bos_token,tensor_parallel_size=$TENSOR_PARALLEL_SIZE,gpu_memory_utilization=$GPU_MEMORY_UTILIZATION,data_parallel_size=1,max_model_len=8192,kv_cache_dtype=${KV_CACHE_DTYPE} \
--tasks $tasks \
--batch_size $BATCH_SIZE \
$extra_args

File: run_quant.sh
@@ -3,6 +3,7 @@
# Usage: CUDA_VISIBLE_DEVICES=0 bash run_quant.sh --topology=Llama-3.1-8B --dtype=mxfp8 --input_model=/models/Meta-Llama-3.1-8B-Instruct --output_model=Llama-3.1-8B-MXFP8

# Parse command line arguments
KV_CACHE_DTYPE="auto"
while [[ $# -gt 0 ]]; do
case $1 in
--topology=*)
@@ -21,6 +22,10 @@ while [[ $# -gt 0 ]]; do
OUTPUT_MODEL="${1#*=}"
shift
;;
--static_kv_dtype=*)
KV_CACHE_DTYPE="${1#*=}"
shift
;;
*)
echo "Unknown parameter: $1"
exit 1
@@ -43,7 +48,11 @@ echo " Input Model: $INPUT_MODEL"
echo " Output Model: $OUTPUT_MODEL"

# Set common parameters
COMMON_ARGS="--quantize --enable_torch_compile --low_gpu_mem_usage --export_format auto_round"
if [ "$KV_CACHE_DTYPE" = "auto" ]; then
COMMON_ARGS="--quantize --enable_torch_compile --low_gpu_mem_usage --export_format auto_round"
else
COMMON_ARGS="--quantize --enable_torch_compile --low_gpu_mem_usage --export_format auto_round --static_kv_dtype $KV_CACHE_DTYPE"
fi

case "$TOPOLOGY" in
"Llama-3.1-8B")