Skip to content

Commit 860303b

Browse files
mvanhorn and claude committed
feat: add --dry-run estimation mode
Add a --dry-run flag to the CLI that estimates VRAM usage, output file size, and approximate quantization time without running the full quantization process. Uses AutoConfig to load model architecture metadata without downloading weights. New module: auto_round/estimation.py with estimation functions for parameter count, peak VRAM, output size, and time. Relates to #1551 and #1584 Fixes #1591 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
1 parent 79fa1a9 commit 860303b

File tree

3 files changed

+445
-0
lines changed

3 files changed

+445
-0
lines changed

auto_round/__main__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,14 @@ def __init__(self, *args, **kwargs):
168168
basic.add_argument(
169169
"--enable_torch_compile", action="store_true", help="Enable PyTorch compilation for faster execution. "
170170
)
171+
basic.add_argument(
172+
"--dry_run",
173+
"--dry-run",
174+
action="store_true",
175+
help="Estimate VRAM usage, output file size, and quantization time "
176+
"without running the full quantization process. "
177+
"Loads only the model config (no weights) and prints a summary.",
178+
)
171179
basic.add_argument(
172180
"--disable_trust_remote_code",
173181
action="store_true",
@@ -613,6 +621,30 @@ def tune(args):
613621
scheme = args.scheme.upper()
614622
if scheme not in PRESET_SCHEMES:
615623
raise ValueError(f"{scheme} is not supported. only {PRESET_SCHEMES.keys()} are supported ")
624+
625+
if args.dry_run:
626+
from auto_round.estimation import dry_run_estimate, print_dry_run_report
627+
628+
scheme_obj = PRESET_SCHEMES[scheme]
629+
target_bits = args.bits if args.bits is not None else scheme_obj.bits
630+
group_size = args.group_size if args.group_size is not None else scheme_obj.group_size
631+
632+
model_dtype = args.model_dtype or "float16"
633+
estimates = dry_run_estimate(
634+
model_name=model_name,
635+
scheme_bits=target_bits,
636+
group_size=group_size,
637+
model_dtype=model_dtype,
638+
batch_size=args.batch_size,
639+
seqlen=args.seqlen,
640+
nsamples=args.nsamples,
641+
iters=args.iters,
642+
trust_remote_code=not args.disable_trust_remote_code,
643+
platform=args.platform,
644+
)
645+
print_dry_run_report(estimates)
646+
return
647+
616648
if args.disable_deterministic_algorithms:
617649
logger.warning(
618650
"default not use deterministic_algorithms. disable_deterministic_algorithms is deprecated,"

auto_round/estimation.py

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Dry-run estimation utilities for AutoRound.
16+
17+
Estimates VRAM usage, output file size, and approximate quantization time
18+
from model configuration metadata without loading model weights.
19+
"""
20+
21+
import math
22+
23+
from auto_round.logger import logger
24+
25+
# Bytes per element for supported source dtypes. Both the long and the
# short spelling of each dtype map to the same width so user-supplied
# --model_dtype strings resolve without normalization.
DTYPE_BYTES = {
    "float32": 4,
    "fp32": 4,
    "float16": 2,
    "fp16": 2,
    "bfloat16": 2,
    "bf16": 2,
    "float8_e4m3fn": 1,
    "fp8": 1,
}

# Rough seconds per layer per iteration, measured on A100 for a 7B-class model.
# Actual speed varies widely by hardware and model architecture.
_SECS_PER_LAYER_PER_ITER = 0.12
39+
40+
41+
def _count_parameters(config):
42+
"""Estimate total parameter count from a transformers model config.
43+
44+
Uses hidden_size, intermediate_size, num_hidden_layers, and vocab_size
45+
to compute a rough parameter count. Falls back to a simple
46+
hidden_size^2 * num_layers heuristic when fields are missing.
47+
"""
48+
hidden = getattr(config, "hidden_size", None)
49+
intermediate = getattr(config, "intermediate_size", None)
50+
num_layers = getattr(config, "num_hidden_layers", None)
51+
vocab_size = getattr(config, "vocab_size", None)
52+
num_attention_heads = getattr(config, "num_attention_heads", None)
53+
num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)
54+
55+
if hidden is None or num_layers is None:
56+
return None
57+
58+
# Attention: Q, K, V projections + output projection
59+
head_dim = hidden // num_attention_heads if num_attention_heads else hidden
60+
q_params = hidden * hidden # Q projection
61+
k_params = hidden * (num_key_value_heads * head_dim if num_key_value_heads else hidden)
62+
v_params = k_params
63+
o_params = hidden * hidden # output projection
64+
attn_params = q_params + k_params + v_params + o_params
65+
66+
# FFN: gate + up + down projections (for gated architectures like LLaMA)
67+
if intermediate is not None:
68+
ffn_params = 3 * hidden * intermediate # gate_proj + up_proj + down_proj
69+
else:
70+
ffn_params = 4 * hidden * hidden # classic 4x expansion
71+
72+
# Per-layer params (attention + ffn + layer norms)
73+
layer_params = attn_params + ffn_params + 2 * hidden # 2 layer norms
74+
75+
total = num_layers * layer_params
76+
77+
# Embedding + LM head
78+
if vocab_size is not None:
79+
embedding_params = vocab_size * hidden
80+
# Most models tie embeddings and lm_head
81+
tie_word_embeddings = getattr(config, "tie_word_embeddings", True)
82+
if tie_word_embeddings:
83+
total += embedding_params
84+
else:
85+
total += 2 * embedding_params
86+
87+
return total
88+
89+
90+
def _format_bytes(num_bytes):
91+
"""Format byte count as a human-readable string."""
92+
if num_bytes >= 1e12:
93+
return f"{num_bytes / 1e12:.2f} TB"
94+
if num_bytes >= 1e9:
95+
return f"{num_bytes / 1e9:.2f} GB"
96+
if num_bytes >= 1e6:
97+
return f"{num_bytes / 1e6:.2f} MB"
98+
return f"{num_bytes / 1e3:.2f} KB"
99+
100+
101+
def _format_time(seconds):
102+
"""Format seconds as a human-readable time string."""
103+
if seconds >= 3600:
104+
hours = seconds / 3600
105+
return f"{hours:.1f} hours"
106+
if seconds >= 60:
107+
minutes = seconds / 60
108+
return f"{minutes:.1f} minutes"
109+
return f"{seconds:.0f} seconds"
110+
111+
112+
def estimate_vram(param_count, model_dtype_bytes, batch_size, seqlen, hidden_size):
    """Estimate peak VRAM usage in bytes during quantization.

    The estimate is the sum of:
      - the full model weights in the original dtype,
      - a rough calibration-activation buffer (batch * seqlen * hidden),
      - ~5% of the weight footprint for one block's optimizer state
        (Adam momentum + variance on the block being tuned),
    padded by 20% for CUDA context overhead and allocator fragmentation.
    """
    weights = param_count * model_dtype_bytes
    # Rough upper bound on activation memory for one calibration block.
    activations = batch_size * seqlen * hidden_size * model_dtype_bytes
    # Only one block is tuned at a time, so its optimizer state is a small
    # fraction of the total model size.
    optimizer_state = weights * 0.05
    padded = (weights + activations + optimizer_state) * 1.2
    return int(padded)
136+
137+
138+
def estimate_output_size(param_count, target_bits, group_size):
    """Estimate the on-disk size in bytes of the quantized model.

    Counts ``target_bits`` per weight plus, for grouped schemes
    (``group_size > 0``), one fp16 scale (16 bits) and one zero-point
    packed into ``target_bits`` per quantization group. Non-positive
    group sizes contribute no per-group overhead.
    """
    total_bits = param_count * target_bits
    if group_size > 0:
        num_groups = math.ceil(param_count / group_size)
        total_bits += num_groups * (16 + target_bits)
    return int(math.ceil(total_bits / 8))
156+
157+
158+
def estimate_time(num_layers, iters, nsamples, batch_size):
    """Estimate approximate quantization wall time in seconds.

    Derived from empirical per-layer, per-iteration timing on an A100 for
    a 7B-class model; actual time varies significantly by hardware, model
    architecture, and sequence length.
    """
    calib_batches = math.ceil(nsamples / batch_size)
    return num_layers * iters * calib_batches * _SECS_PER_LAYER_PER_ITER
167+
168+
169+
def dry_run_estimate(model_name, scheme_bits, group_size, model_dtype="float16",
                     batch_size=8, seqlen=2048, nsamples=128, iters=200,
                     trust_remote_code=True, platform="hf"):
    """Run a dry-run estimation and return a dict of estimates.

    Only the model config is fetched (no weights), so this is cheap even
    for very large models.

    Args:
        model_name: HuggingFace model name or local path.
        scheme_bits: Target quantization bit width (e.g. 4 for W4A16).
        group_size: Quantization group size.
        model_dtype: Original model data type string.
        batch_size: Calibration batch size.
        seqlen: Calibration sequence length.
        nsamples: Number of calibration samples.
        iters: Number of tuning iterations.
        trust_remote_code: Whether to trust remote code when loading config.
        platform: Platform to load model config from.

    Returns:
        dict with keys: param_count, peak_vram_bytes, output_size_bytes,
        estimated_time_secs, and their formatted string versions; ``None``
        when the parameter count cannot be estimated from the config.
    """
    # ModelScope mirrors the transformers AutoConfig API.
    if platform == "model_scope":
        from modelscope import AutoConfig
    else:
        from transformers import AutoConfig

    config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)

    param_count = _count_parameters(config)
    if param_count is None:
        logger.warning("Could not estimate parameter count from model config.")
        return None

    hidden_size = getattr(config, "hidden_size", 4096)
    num_layers = getattr(config, "num_hidden_layers", 32)
    # Unknown dtype strings fall back to 2 bytes (fp16/bf16-class).
    dtype_bytes = DTYPE_BYTES.get(model_dtype, 2)

    vram_bytes = estimate_vram(param_count, dtype_bytes, batch_size, seqlen, hidden_size)
    size_bytes = estimate_output_size(param_count, scheme_bits, group_size)
    time_secs = estimate_time(num_layers, iters, nsamples, batch_size)

    if param_count >= 1e9:
        params_str = f"{param_count / 1e9:.2f}B"
    else:
        params_str = f"{param_count / 1e6:.1f}M"

    report = {
        "model_name": model_name,
        "param_count": param_count,
        "param_count_str": params_str,
        "peak_vram_bytes": vram_bytes,
        "peak_vram_str": _format_bytes(vram_bytes),
        "output_size_bytes": size_bytes,
        "output_size_str": _format_bytes(size_bytes),
        "estimated_time_secs": time_secs,
        "estimated_time_str": _format_time(time_secs),
        "scheme_bits": scheme_bits,
        "group_size": group_size,
        "model_dtype": model_dtype,
        "batch_size": batch_size,
        "seqlen": seqlen,
        "nsamples": nsamples,
        "iters": iters,
        "num_layers": num_layers,
    }
    return report
230+
231+
232+
def print_dry_run_report(estimates):
    """Print a formatted dry-run estimation report to stdout."""
    if estimates is None:
        logger.error("Dry-run estimation failed: could not determine model parameters.")
        return

    border = "=" * 60
    report_lines = [
        f"\n{border}",
        " AutoRound Dry-Run Estimation",
        border,
        f" Model: {estimates['model_name']}",
        f" Parameters: {estimates['param_count_str']}",
        f" Layers: {estimates['num_layers']}",
        f" Target bits: {estimates['scheme_bits']}",
        f" Group size: {estimates['group_size']}",
        f" Model dtype: {estimates['model_dtype']}",
        border,
        f" Estimated peak VRAM: {estimates['peak_vram_str']}",
        f" Estimated output size: {estimates['output_size_str']}",
        f" Estimated time: {estimates['estimated_time_str']}",
        f" (batch_size={estimates['batch_size']}, seqlen={estimates['seqlen']}, "
        f"nsamples={estimates['nsamples']}, iters={estimates['iters']})",
        border,
        " NOTE: These are rough estimates. Actual values depend on",
        " hardware, model architecture, and runtime conditions.",
        f"{border}\n",
    ]
    for line in report_lines:
        print(line)

0 commit comments

Comments
 (0)