Skip to content

Commit 0a5218c

Browse files
authored
Fixed Quantization bug in TransformerLens 3.0 (#1276)
* Fixed Quantization bug in TransformerLens 3.0
* Format fixes
1 parent fd288dc commit 0a5218c

3 files changed

Lines changed: 98 additions & 8 deletions

File tree

tests/integration/model_bridge/test_bridge_integration.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,86 @@ def hook_fn(grad, hook=None):
718718
assert hook_called["bridge"], "TransformerBridge backward hook should now be called correctly"
719719

720720

721+
def test_AttentionBridge_preserves_fp_input_when_first_param_is_quantized():
    """Floating-point inputs must reach HF attention uncast when weights are quantized.

    Guards against an AttentionBridge / GeneralizedComponent regression in
    which ``target_dtype`` was taken from ``next(parameters()).dtype``. With
    quantized weights that is the *storage* dtype (uint8 for BnB Params4bit,
    int32 for GPTQ, etc.), so whenever the first parameter was quantized the
    bridge cast fp32 hidden states to an integer dtype before handing them to
    HF, destroying precision and producing gibberish logits.

    The test simulates a quantized first parameter by swapping
    ``q_proj.weight`` for a uint8 tensor, stubs out the original component's
    forward to record the dtype it is given, runs the bridge once, and checks
    that every recorded dtype is floating point.
    """
    from transformer_lens.model_bridge.generalized_components.attention import (
        AttentionBridge,
    )

    # Tiny Mistral resolves to a plain AttentionBridge (not a JointQKV variant).
    bridge: TransformerBridge = TransformerBridge.boot_transformers(  # type: ignore
        "trl-internal-testing/tiny-MistralForCausalLM-0.2", device="cpu"
    )

    attn_bridge = bridge.blocks[0].attn  # type: ignore[attr-defined]
    assert (
        type(attn_bridge).__name__ == "AttentionBridge"
    ), f"Expected plain AttentionBridge, got {type(attn_bridge).__name__}"
    assert isinstance(attn_bridge, AttentionBridge)

    hf_attn = attn_bridge.original_component
    assert hf_attn is not None, "AttentionBridge.original_component not set"

    # Swap q_proj's weight for uint8 storage, imitating BnB Params4bit.
    saved_weight = hf_attn.q_proj.weight
    fake_quantized = torch.zeros(saved_weight.shape, dtype=torch.uint8)
    hf_attn.q_proj.weight = torch.nn.Parameter(fake_quantized, requires_grad=False)
    first_param = next(hf_attn.parameters())
    assert (
        first_param.dtype == torch.uint8
    ), "Test setup: first param should be uint8 to trigger the bug condition"

    # Dtypes observed at the original component's forward, in call order.
    seen_dtypes: list = []
    saved_forward = hf_attn.forward

    def capture(*args, **kwargs):
        # Locate the hidden states the same way for recording and for shaping
        # the dummy output: keyword first, then first positional argument.
        passed_by_keyword = "hidden_states" in kwargs
        if passed_by_keyword:
            hidden = kwargs["hidden_states"]
        else:
            hidden = args[0] if args else None
        if passed_by_keyword or args:
            seen_dtypes.append(hidden.dtype)
        # The real forward is never run: the fake-quantized weight would error.
        # Hand back a shape-compatible tuple, as HF Mistral attention does.
        batch, seq_len, width = hidden.shape
        head_count = bridge.cfg.n_heads  # type: ignore[attr-defined]
        dummy_output = torch.zeros(batch, seq_len, width, dtype=torch.float32)
        dummy_pattern = torch.zeros(
            batch, head_count, seq_len, seq_len, dtype=torch.float32
        )
        return (dummy_output, dummy_pattern)

    hf_attn.forward = capture  # type: ignore[method-assign]
    try:
        tokens = torch.tensor([[1, 2, 3, 4, 5]])
        with torch.no_grad():
            try:
                bridge(tokens)
            except Exception:
                # Downstream layers may fail on the dummy output; only the
                # dtype that reached the attention forward matters here.
                pass
    finally:
        # Always undo the monkeypatches, forward first and then the weight.
        hf_attn.forward = saved_forward  # type: ignore[method-assign]
        hf_attn.q_proj.weight = saved_weight

    assert len(seen_dtypes) > 0, "Original attention forward never called"
    # Report the first non-floating-point dtype, if any was observed.
    dt = next((seen for seen in seen_dtypes if not seen.is_floating_point), None)
    assert dt is None, (
        f"Bridge passed dtype={dt} to original attention forward, but it must be "
        f"floating point. Regression of the AttentionBridge dtype-cast bug — "
        f"target_dtype must skip non-fp (quantized-storage) parameters."
    )
799+
800+
721801
@pytest.mark.skipif(bool(os.getenv("CI")), reason="Skip Gemma2 test in CI to avoid timeout")
722802
def test_TransformerBridge_gemma2_forward():
723803
"""Test that TransformerBridge properly handles Gemma2's position_embeddings.

transformer_lens/model_bridge/generalized_components/attention.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -622,11 +622,16 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
622622
raise RuntimeError(
623623
f"Original component not set for {self.name}. Call set_original_component() first."
624624
)
625+
# Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
626+
# HQQ, torchao) are stored in integer dtypes and dequantized internally
627+
# during matmul. The compute dtype must come from a fp parameter; casting
628+
# fp inputs to an integer storage dtype destroys precision.
625629
target_dtype = None
626-
try:
627-
target_dtype = next(self.original_component.parameters()).dtype
628-
except StopIteration:
629-
pass
630+
for p in self.original_component.parameters():
631+
if not p.dtype.is_floating_point:
632+
continue
633+
target_dtype = p.dtype
634+
break
630635
if "query_input" in kwargs:
631636
hooked = self.hook_in(kwargs["query_input"])
632637
if (

transformer_lens/model_bridge/generalized_components/base.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,11 +274,16 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
274274
raise RuntimeError(
275275
f"Original component not set for {self.name}. Call set_original_component() first."
276276
)
277+
# Skip non-fp params: quantized weights (bnb uint8/int8, GPTQ/AWQ int32,
278+
# HQQ, torchao) are stored in integer dtypes and dequantized internally
279+
# during matmul. The compute dtype must come from a fp parameter; casting
280+
# fp inputs to an integer storage dtype destroys precision.
277281
target_dtype = None
278-
try:
279-
target_dtype = next(original_component.parameters()).dtype
280-
except StopIteration:
281-
pass
282+
for p in original_component.parameters():
283+
if not p.dtype.is_floating_point:
284+
continue
285+
target_dtype = p.dtype
286+
break
282287
input_arg_names = [
283288
"input",
284289
"hidden_states",

0 commit comments

Comments
 (0)