@@ -718,6 +718,86 @@ def hook_fn(grad, hook=None):
718718 assert hook_called ["bridge" ], "TransformerBridge backward hook should now be called correctly"
719719
720720
721+ def test_AttentionBridge_preserves_fp_input_when_first_param_is_quantized ():
722+ """Bridge must not cast fp inputs to integer storage dtype.
723+
724+ Regression for an AttentionBridge / GeneralizedComponent bug where
725+ `target_dtype = next(parameters()).dtype` returned the storage dtype of
726+ quantized weights (uint8 for BnB Params4bit, int32 for GPTQ, etc.). When
727+ the first parameter happened to be quantized, bridge cast fp32 hidden_states
728+ to that integer dtype before passing them to HF — destroying precision and
729+ producing gibberish logits on every quantized model.
730+
731+ Fakes a "quantized first parameter" by replacing q_proj.weight with a
732+ uint8 tensor, then runs a forward and asserts the input the original
733+ component receives is still floating-point.
734+ """
735+ from transformer_lens .model_bridge .generalized_components .attention import (
736+ AttentionBridge ,
737+ )
738+
739+ # Use tiny Mistral — it's a plain AttentionBridge (not JointQKV).
740+ bridge : TransformerBridge = TransformerBridge .boot_transformers ( # type: ignore
741+ "trl-internal-testing/tiny-MistralForCausalLM-0.2" , device = "cpu"
742+ )
743+
744+ attn_bridge = bridge .blocks [0 ].attn # type: ignore[attr-defined]
745+ assert (
746+ type (attn_bridge ).__name__ == "AttentionBridge"
747+ ), f"Expected plain AttentionBridge, got { type (attn_bridge ).__name__ } "
748+ assert isinstance (attn_bridge , AttentionBridge )
749+
750+ original = attn_bridge .original_component
751+ assert original is not None , "AttentionBridge.original_component not set"
752+
753+ # Fake-quantize q_proj to uint8 storage — mirrors BnB Params4bit.
754+ fp_weight = original .q_proj .weight
755+ original .q_proj .weight = torch .nn .Parameter (
756+ torch .zeros (fp_weight .shape , dtype = torch .uint8 ), requires_grad = False
757+ )
758+ assert (
759+ next (original .parameters ()).dtype == torch .uint8
760+ ), "Test setup: first param should be uint8 to trigger the bug condition"
761+
762+ # Capture what dtype reaches the original component's forward.
763+ received_dtype : list = []
764+ orig_forward = original .forward
765+
766+ def capture (* args , ** kwargs ):
767+ if "hidden_states" in kwargs :
768+ received_dtype .append (kwargs ["hidden_states" ].dtype )
769+ elif args :
770+ received_dtype .append (args [0 ].dtype )
771+ # Don't actually run forward — fake-quantized weight would error.
772+ # Return a shape-compatible dummy. HF Mistral attention returns a tuple.
773+ bsz , seq , d_model = (kwargs .get ("hidden_states" , args [0 ] if args else None )).shape
774+ n_heads = bridge .cfg .n_heads # type: ignore[attr-defined]
775+ return (
776+ torch .zeros (bsz , seq , d_model , dtype = torch .float32 ),
777+ torch .zeros (bsz , n_heads , seq , seq , dtype = torch .float32 ),
778+ )
779+
780+ original .forward = capture # type: ignore[method-assign]
781+ try :
782+ test_input = torch .tensor ([[1 , 2 , 3 , 4 , 5 ]])
783+ with torch .no_grad ():
784+ try :
785+ bridge (test_input )
786+ except Exception :
787+ pass # downstream may fail; we only care what reached attn forward
788+ finally :
789+ original .forward = orig_forward # type: ignore[method-assign]
790+ original .q_proj .weight = fp_weight
791+
792+ assert len (received_dtype ) > 0 , "Original attention forward never called"
793+ for dt in received_dtype :
794+ assert dt .is_floating_point , (
795+ f"Bridge passed dtype={ dt } to original attention forward, but it must be "
796+ f"floating point. Regression of the AttentionBridge dtype-cast bug — "
797+ f"target_dtype must skip non-fp (quantized-storage) parameters."
798+ )
799+
800+
721801@pytest .mark .skipif (bool (os .getenv ("CI" )), reason = "Skip Gemma2 test in CI to avoid timeout" )
722802def test_TransformerBridge_gemma2_forward ():
723803 """Test that TransformerBridge properly handles Gemma2's position_embeddings.
0 commit comments