removing 8da4w-gptq #11274

Merged: 3 commits, Jun 3, 2025

The diff below shows changes from 1 of the 3 commits.
54 changes: 1 addition & 53 deletions in examples/models/llama/source_transformation/quantize.py
```diff
--- a/examples/models/llama/source_transformation/quantize.py
+++ b/examples/models/llama/source_transformation/quantize.py
@@ -57,7 +57,7 @@ def quantize(  # noqa C901
 
     Args:
         model: The model to quantize.
-        qmode: The quantization mode, e.g. int8, 8da4w, 8da4w-gptq.
+        qmode: The quantization mode, e.g. int8, 8da4w.
         computation_dtype: The dtype that ops are performed in (the resulting dtype of dequantization).
             Also the dtype of the rest of the non-quantized components of the model.
         checkpoint_dtype: The dtype of the checkpoint, this arg exists since it is more accurate to
@@ -161,58 +161,6 @@ def quantize(  # noqa C901
         if verbose:
             print("quantized model:", model)
         return model
-    elif qmode == "8da4w-gptq":
-        # Check for required args
-        required_args: Optional[Any] = [
-            group_size,
-            calibration_limit,
-            calibration_seq_length,
-        ]
-        if any(arg is None for arg in required_args):
-            raise Exception(
-                "For 8da4w-gptq quantization, group size, calibration limit and calibration sequence length must be specified."
-            )
-        if calibration_tasks is None:
-            calibration_tasks = ["wikitext"]
-
-        try:
-            # torchao 0.3+
-            from torchao._models._eval import InputRecorder
-        except ImportError:
-            from torchao.quantization.GPTQ import InputRecorder  # pyre-ignore
-
-        from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer
-
-        if tokenizer_path is None:
-            assert checkpoint_path is not None, "checkpoint_path must be specified"
-            tokenizer_path = checkpoint_path.parent / "tokenizer.model"
-        assert tokenizer_path.is_file(), tokenizer_path
-        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
-            model_file=str(tokenizer_path)
-        )
-
-        inputs = (
-            InputRecorder(  # pyre-fixme[16]
-                tokenizer,
-                calibration_seq_length,
-                None,  # input_prep_func
-                pad_calibration_inputs,
-                model.vocab_size,
-            )
-            .record_inputs(
-                calibration_tasks,
-                calibration_limit,
-            )
-            .get_inputs()
-        )
-
-        gptq_quantizer = Int8DynActInt4WeightGPTQQuantizer(
-            blocksize,
-            percdamp,
-            group_size,
-        )  # TODO: separate computation and checkpoint dtype for GPTQ.
-        model = gptq_quantizer.quantize(model, inputs)
-        return model
     elif qmode == "vulkan_4w":
         from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer
 
```
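For readers who relied on the deleted branch, here is a minimal standalone sketch of the flow this PR removes, reassembled from the deleted lines above. The import paths and the `InputRecorder` / `Int8DynActInt4WeightGPTQQuantizer` signatures follow the torchao ~0.3-era APIs the old code targeted and may not exist in current torchao releases; the default argument values below are illustrative, not verified library defaults.

```python
# Historical sketch of 8da4w-gptq quantization (int8 dynamic activations,
# int4 weights, GPTQ calibration), reassembled from the branch deleted in
# this PR. Import paths follow torchao ~0.3 and may have moved since.
from sentencepiece import SentencePieceProcessor

try:
    # torchao 0.3+ location of the calibration-input recorder
    from torchao._models._eval import InputRecorder
except ImportError:
    from torchao.quantization.GPTQ import InputRecorder
from torchao.quantization.quant_api import Int8DynActInt4WeightGPTQQuantizer


def gptq_8da4w(
    model,                        # eager Llama-style model exposing .vocab_size
    tokenizer_path,               # path to a SentencePiece tokenizer.model
    blocksize=128,                # illustrative defaults, not library defaults
    percdamp=0.01,
    group_size=128,
    calibration_tasks=("wikitext",),
    calibration_limit=100,
    calibration_seq_length=2048,
    pad_calibration_inputs=False,
):
    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))

    # Record real eval-task inputs; GPTQ uses them to compensate the
    # quantization error of each weight group during rounding.
    inputs = (
        InputRecorder(
            tokenizer,
            calibration_seq_length,
            None,  # input_prep_func
            pad_calibration_inputs,
            model.vocab_size,
        )
        .record_inputs(list(calibration_tasks), calibration_limit)
        .get_inputs()
    )

    quantizer = Int8DynActInt4WeightGPTQQuantizer(blocksize, percdamp, group_size)
    return quantizer.quantize(model, inputs)
```

After this change, the modes named in the updated docstring (e.g. int8, 8da4w) and the `vulkan_4w` branch visible at the bottom of the diff are the quantization paths this file still handles.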