Skip to content
Open
45 changes: 23 additions & 22 deletions src/llmcompressor/modifiers/gptq/gptq_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,19 @@ def quantize_weight(

scale, zero_point = observer(W)
# handle g_idx and activation ordering
if strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
g_idx_to_save = None
if strategy in (
QuantizationStrategy.GROUP,
QuantizationStrategy.TENSOR_GROUP,
QuantizationStrategy.BLOCK,
):
# mapping from column index to group index
g_idx = (
torch.arange(num_columns, device=W.device, dtype=torch.int)
// quant_args.group_size
divisor = (
quant_args.group_size
Comment thread
rk119 marked this conversation as resolved.
if strategy != QuantizationStrategy.BLOCK
else quant_args.block_structure[1]
)
g_idx = torch.arange(num_columns, device=W.device, dtype=torch.int) // divisor

if actorder == ActivationOrdering.GROUP:
W, H, perm = _apply_activation_ordering(W, H)
Expand Down Expand Up @@ -217,8 +224,8 @@ def quantize_weight(
global_scale=global_scale,
)
elif strategy == QuantizationStrategy.BLOCK:
block_width = quant_args.block_structure[1]
block_column_idx = (i1 + i) // block_width
column_idx = i1 + i
Comment thread
rk119 marked this conversation as resolved.
block_column_idx = g_idx[column_idx]
q = fake_quantize(
q.unsqueeze(1),
scale[:, block_column_idx : block_column_idx + 1],
Expand Down Expand Up @@ -253,24 +260,18 @@ def quantize_weight(
else:
W[:, i2:] -= w_err

has_gidx = False
if strategy in (QuantizationStrategy.GROUP, QuantizationStrategy.TENSOR_GROUP):
if actorder == ActivationOrdering.WEIGHT:
# restore original permutation
invperm = torch.argsort(perm)
W = W[:, invperm]

elif actorder == ActivationOrdering.GROUP:
if strategy in (
Comment thread
rk119 marked this conversation as resolved.
Outdated
QuantizationStrategy.GROUP,
QuantizationStrategy.TENSOR_GROUP,
QuantizationStrategy.BLOCK,
):
if actorder in (ActivationOrdering.WEIGHT, ActivationOrdering.GROUP):
# restore original permutation
invperm = torch.argsort(perm)
W = W[:, invperm]
g_idx = g_idx[invperm]

# only save g_idx if mapping is not identity
has_gidx = True

if not has_gidx:
g_idx = None
if actorder == ActivationOrdering.GROUP:
Copy link
Copy Markdown
Collaborator

@HDCharles HDCharles Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can simplify this: only group activation ordering saves g_idx, so we can just check for that.
There is no need to check twice; just do this line at line 287 and remove g_idx_to_save.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, done!

g_idx_to_save = g_idx[invperm]

if isinstance(module, transformers.Conv1D):
W.transpose_(0, 1)
Expand All @@ -282,8 +283,8 @@ def quantize_weight(
"weight_scale": scale.to(dtype=final_dtype),
"weight_zero_point": zero_point.to(dtype=quant_args.zp_dtype),
}
if g_idx is not None:
q_param_dict["weight_g_idx"] = g_idx
if g_idx_to_save is not None:
q_param_dict["weight_g_idx"] = g_idx_to_save
return (loss, q_param_dict)


Expand Down
33 changes: 32 additions & 1 deletion tests/llmcompressor/transformers/gptq/test_gptq_oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,35 @@
},
)

# Block-strategy recipe: GPTQ with 8-bit weights quantized using
# strategy="block" and block_structure=[2, 8]; actorder is left at its
# default. Only modules matching model.layers.2.self_attn.q_proj are
# targeted, and lm_head is ignored.
# NOTE(review): block_structure is presumably [block_height, block_width];
# confirm against the QuantizationArgs definition.
recipe_modifier_full_block = GPTQModifier(
ignore=["lm_head"],
config_groups={
"group_0": QuantizationScheme(
targets=["re:.*model.layers.2.self_attn.q_proj$"],
weights=QuantizationArgs(
num_bits=8,
strategy="block",
block_structure=[2, 8],
),
)
},
)

# Block-strategy recipe with activation ordering: same 8-bit,
# block_structure=[2, 8] weight configuration on
# model.layers.2.self_attn.q_proj, but with
# actorder=ActivationOrdering.WEIGHT so the block strategy is exercised
# together with weight activation ordering.
recipe_modifier_block_actorder_weight = GPTQModifier(
ignore=["lm_head"],
config_groups={
"group_0": QuantizationScheme(
targets=["re:.*model.layers.2.self_attn.q_proj$"],
weights=QuantizationArgs(
num_bits=8,
strategy="block",
block_structure=[2, 8],
actorder=ActivationOrdering.WEIGHT,
),
)
},
)
Comment thread
rk119 marked this conversation as resolved.
Outdated

@pytest.mark.parametrize(
"recipe",
Expand All @@ -107,6 +136,8 @@
recipe_modifier_shorthand_b,
recipe_modifier_group_actorder_weight,
recipe_modifier_group_actorder_group,
recipe_modifier_full_block,
recipe_modifier_block_actorder_weight,
],
)
def test_oneshot_application(recipe, tmp_path):
Expand Down Expand Up @@ -154,7 +185,7 @@ def test_oneshot_application(recipe, tmp_path):
assert quant_scheme.targets == ["re:.*model.layers.2.self_attn.q_proj$"]
weight_args = quantization_config.config_groups["group_0"].weights
assert isinstance(weight_args, QuantizationArgs)
assert weight_args.num_bits == 4
assert weight_args.num_bits == 4 or weight_args.num_bits == 8

# Check a specific layer is quantized
targetted_linear_layer = model_loaded.model.layers[2].self_attn.q_proj
Expand Down
Loading