refactor: use config.is_causal=False for bidirectional attention#1837
Draft
oliverholworthy wants to merge 4 commits into main from
Conversation
Leverage the transformers 5.3+ "power feature" in `create_causal_mask` that redirects to `create_bidirectional_mask` when `config.is_causal` is `False`, eliminating the need for a custom `forward()` override in `LlamaBidirectionalModel`.

Changes:

- Remove the `LlamaBidirectionalModel.forward()` override; set `config.is_causal=False` in `__init__` so the parent `LlamaModel.forward()` produces bidirectional masks automatically
- Pass `is_causal=False` from `BiEncoderModel.encode()` to the model forward, which propagates through `**kwargs` to the sdpa/flash attention functions; this is load-bearing for FA2 correctness
- Add an `extract_submodel` parameter to `build_encoder_backbone` for generic VLM text-backbone extraction via a dotted attribute path
- Set `config.is_causal=False` and `layer.self_attn.is_causal=False` on models loaded through the generic (non-`SUPPORTED_BACKBONES`) path

Signed-off-by: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com>
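The redirect the description relies on can be reduced to a few lines of plain Python. This is a hedged sketch, not the transformers implementation: `SimpleConfig` and both mask builders are hypothetical stand-ins for the real masking utilities.

```python
# Minimal sketch of the config.is_causal redirect (hypothetical names;
# the real logic lives in transformers' masking utilities).

class SimpleConfig:
    """Stand-in for a HF model config."""
    pass

def build_bidirectional_mask(seq_len):
    # Every position may attend to every other position.
    return [[True] * seq_len for _ in range(seq_len)]

def build_causal_mask_dispatch(config, seq_len):
    # The redirect: when config.is_causal is False, a request for a
    # "causal" mask yields a bidirectional one instead.
    if not getattr(config, "is_causal", True):
        return build_bidirectional_mask(seq_len)
    # Default: lower-triangular causal mask (position i sees 0..i).
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]

cfg = SimpleConfig()
cfg.is_causal = False
mask = build_causal_mask_dispatch(cfg, 3)  # fully True: bidirectional
```

Note the `getattr(..., True)` default: a config without the attribute keeps the usual causal behavior, which is why existing checkpoints are unaffected.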
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
…extract_submodel

Move the new `is_causal` tests into the existing test file to match repo conventions (test files are named after what they test, not after the change). Remove `extract_submodel` from this PR; it will be a separate PR.

Signed-off-by: Oliver Holworthy <oholworthy@nvidia.com>
Signed-off-by: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com>
Move tests for `BiEncoderModel.encode()` and `build_encoder_backbone()` to `tests/unit_tests/_transformers/test_retrieval.py` to match the source module location. Keep `LlamaBidirectionalModel`-specific tests in `test_llama_bidirectional_model.py`.

Signed-off-by: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com>
/ok to test 1cc9623
What does this PR do?
Leverage the transformers 5.3+ `create_causal_mask` power feature (`config.is_causal`) to remove the custom `forward()` override from `LlamaBidirectionalModel`, and make the generic `build_encoder_backbone` path produce bidirectional models without per-architecture custom classes.

Changelog

- Remove the `LlamaBidirectionalModel.forward()` override; set `config.is_causal = False` in `__init__` so the parent `LlamaModel.forward()` produces bidirectional masks automatically via the `create_causal_mask` redirect
- Pass the `is_causal=False` kwarg from `BiEncoderModel.encode()`; it propagates through `**kwargs` to `sdpa_attention_forward` / `flash_attention_forward` and is load-bearing for FA2 correctness
- Set `config.is_causal = False` and `layer.self_attn.is_causal = False` on models loaded through the generic (non-`SUPPORTED_BACKBONES`) path in `build_encoder_backbone`
- … (`DynamicCache`, `create_bidirectional_mask`, `check_model_inputs`, etc.)

Before your PR is "Ready for review"
Pre checks:
Additional Information
Two independent mechanisms must both be non-causal for correct bidirectional attention:
1. `create_causal_mask` reads `getattr(config, "is_causal", True)` and redirects to `create_bidirectional_mask` when it is `False` (`masking_utils.py:873`)
2. The sdpa/flash attention functions accept `is_causal` as a named parameter that takes precedence over `module.is_causal`. The kwarg from `encode()` handles the case where a standard `AutoModel.from_pretrained` loads a checkpoint with `config.is_causal=False` but `module.is_causal` remains `True` (hardcoded in `LlamaAttention.__init__`)
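The precedence rule in mechanism 2 can be illustrated with a toy attention entry point. This is an assumed shape of the rule, not the real signature (`sdpa_attention_forward` takes query/key/value tensors and more):

```python
# Toy illustration of kwarg-over-attribute precedence for is_causal.
# StaleModule and attention_forward are hypothetical stand-ins.

class StaleModule:
    # A checkpoint loaded via plain AutoModel.from_pretrained can end up
    # with config.is_causal=False while this attribute stays True.
    is_causal = True

def attention_forward(module, is_causal=None, **kwargs):
    # An explicit is_causal kwarg wins; otherwise fall back to the module.
    return module.is_causal if is_causal is None else is_causal

mod = StaleModule()
attention_forward(mod)                   # stale attribute used: causal
attention_forward(mod, is_causal=False)  # kwarg from encode() wins
```

This is why the kwarg from `encode()` matters even when the config flag is already set: without it, the module attribute would silently reimpose causal attention in the attention kernel.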
This works for any HF model that calls `create_causal_mask(config=self.config, ...)` in its `forward()` and propagates `**kwargs` through to the attention functions (verified for Llama and Ministral3).

VLM submodel extraction (`extract_submodel`) will follow in a separate PR.
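The generic-path change, flipping both non-causal switches on an already-loaded model, can be sketched as follows. `FakeModel` and friends are mocks standing in for a loaded HF backbone, not NeMo or transformers code:

```python
# Sketch of setting both is_causal flags on a loaded model, mirroring
# the generic (non-SUPPORTED_BACKBONES) path. All classes are mocks.

class FakeAttention:
    def __init__(self):
        self.is_causal = True  # e.g. hardcoded in LlamaAttention.__init__

class FakeLayer:
    def __init__(self):
        self.self_attn = FakeAttention()

class FakeConfig:
    def __init__(self):
        self.is_causal = True

class FakeModel:
    def __init__(self, n_layers=2):
        self.config = FakeConfig()
        self.layers = [FakeLayer() for _ in range(n_layers)]

def make_bidirectional(model):
    # Mechanism 1: mask creation consults config.is_causal.
    model.config.is_causal = False
    # Mechanism 2: each attention module consults its own is_causal.
    for layer in model.layers:
        layer.self_attn.is_causal = False
    return model

model = make_bidirectional(FakeModel())
```

Setting both flags keeps the two mechanisms consistent regardless of whether a given attention backend reads the config or the module attribute.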