
refactor: use config.is_causal=False for bidirectional attention#1837

Draft
oliverholworthy wants to merge 4 commits into main from
oholworthy/is_causal_refactor

Conversation

oliverholworthy (Contributor) commented Apr 14, 2026

What does this PR do ?

Leverage the transformers 5.3+ create_causal_mask support for config.is_causal to remove the custom forward() override from LlamaBidirectionalModel, and to make the generic build_encoder_backbone path produce bidirectional models without per-architecture custom classes.
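The redirect this PR relies on can be sketched in plain Python. This is an illustrative stand-in, not the transformers source: the function bodies here are mocks, and only the dispatch on `getattr(config, "is_causal", True)` mirrors the described behaviour.

```python
# Minimal sketch (mock, not the transformers implementation) of the dispatch
# this PR relies on: create_causal_mask checks config.is_causal and redirects
# to a bidirectional mask builder when it is False.
from types import SimpleNamespace


def create_bidirectional_mask(seq_len):
    # Every position may attend to every other position.
    return [[True] * seq_len for _ in range(seq_len)]


def create_causal_mask(config, seq_len):
    # The redirect: a missing attribute defaults to causal behaviour.
    if not getattr(config, "is_causal", True):
        return create_bidirectional_mask(seq_len)
    # Lower-triangular mask: position i attends only to positions <= i.
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


causal_cfg = SimpleNamespace()                # no is_causal attribute -> causal
bidir_cfg = SimpleNamespace(is_causal=False)  # bidirectional

assert create_causal_mask(causal_cfg, 3)[0] == [True, False, False]
assert create_causal_mask(bidir_cfg, 3)[0] == [True, True, True]
```

Because the attribute defaults to True when absent, existing checkpoints without config.is_causal keep their causal behaviour unchanged.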

Changelog

  • Remove LlamaBidirectionalModel.forward() override; set config.is_causal = False in __init__ so the parent LlamaModel.forward() produces bidirectional masks automatically via create_causal_mask redirect
  • Pass the is_causal=False kwarg from BiEncoderModel.encode(); it propagates through **kwargs to sdpa_attention_forward / flash_attention_forward and is load-bearing for FA2 correctness
  • Set config.is_causal = False and layer.self_attn.is_causal = False on models loaded through the generic (non-SUPPORTED_BACKBONES) path in build_encoder_backbone
  • Clean up unused imports (DynamicCache, create_bidirectional_mask, check_model_inputs, etc.)
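The third bullet's generic-path patching can be sketched as follows. The model here is a dummy stand-in (the real code operates on a loaded HF model inside build_encoder_backbone), but the two attributes it flips are the ones the bullet names.

```python
# Illustrative sketch of the generic (non-SUPPORTED_BACKBONES) patching step:
# flip the config-level flag (read by create_causal_mask) and the per-layer
# module flag (hardcoded True in each attention module's __init__).
from types import SimpleNamespace


def make_bidirectional(model):
    # Mask generation path: create_causal_mask reads config.is_causal.
    model.config.is_causal = False
    # Kernel path: each self-attention module carries its own is_causal flag.
    for layer in model.layers:
        layer.self_attn.is_causal = False
    return model


# Dummy model with two layers, mimicking only the attributes touched above.
model = SimpleNamespace(
    config=SimpleNamespace(is_causal=True),
    layers=[SimpleNamespace(self_attn=SimpleNamespace(is_causal=True)) for _ in range(2)],
)
make_bidirectional(model)
assert model.config.is_causal is False
assert all(layer.self_attn.is_causal is False for layer in model.layers)
```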

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

Additional Information

Two independent mechanisms must both be non-causal for correct bidirectional attention:

  1. Mask generation: create_causal_mask reads getattr(config, "is_causal", True) and redirects to create_bidirectional_mask when False (masking_utils.py:873)
  2. SDPA/FA2 kernel: Both attention backends accept is_causal as a named parameter that takes precedence over module.is_causal. The kwarg from encode() handles the case where a standard AutoModel.from_pretrained loads a checkpoint with config.is_causal=False but module.is_causal remains True (hardcoded in LlamaAttention.__init__)
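The precedence rule in point 2 can be shown with a small mock. attention_forward below is not the real sdpa_attention_forward; it only reproduces the flag resolution described above: an explicit is_causal kwarg wins, and the module attribute is the fallback.

```python
# Mock of the attention-backend flag handling described in point 2:
# a named is_causal parameter takes precedence over module.is_causal.
from types import SimpleNamespace


def attention_forward(module, is_causal=None, **kwargs):
    # Explicit kwarg wins; otherwise fall back to the module attribute.
    return is_causal if is_causal is not None else module.is_causal


attn = SimpleNamespace(is_causal=True)  # hardcoded True, as in LlamaAttention

assert attention_forward(attn) is True                    # no kwarg: module flag used
assert attention_forward(attn, is_causal=False) is False  # kwarg takes precedence
```

This is why the explicit kwarg from encode() matters: with a plain AutoModel.from_pretrained load, config.is_causal=False fixes the mask but the module flag stays True, and only the kwarg overrides it at the kernel level.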

This works for any HF model that calls create_causal_mask(config=self.config, ...) in its forward() and propagates **kwargs through to attention functions (verified for Llama and Ministral3).
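The propagation requirement can be illustrated with a toy forward. The names below (model_forward, encode) are hypothetical stand-ins mirroring the PR description, showing only that a kwarg passed at the encode call site survives **kwargs forwarding down to where the attention backend would read it.

```python
# Toy sketch of the encode-side call path: is_causal=False is passed at the
# top-level call and reaches the bottom of the **kwargs chain intact.
def model_forward(input_ids, **kwargs):
    # A conforming HF model forwards unknown kwargs to its attention functions;
    # here we just surface the flag to show what the backend would receive.
    return {"received_is_causal": kwargs.get("is_causal")}


def encode(model, input_ids):
    # Mirrors the PR: BiEncoderModel.encode() passes is_causal=False explicitly.
    return model(input_ids, is_causal=False)


out = encode(model_forward, [1, 2, 3])
assert out["received_is_causal"] is False
```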

VLM submodel extraction (extract_submodel) will follow in a separate PR.

Leverage the transformers 5.3+ "power feature" in create_causal_mask
that redirects to create_bidirectional_mask when config.is_causal is
False, eliminating the need for a custom forward() override in
LlamaBidirectionalModel.

Changes:
- Remove LlamaBidirectionalModel.forward() override; set
  config.is_causal=False in __init__ so the parent LlamaModel.forward()
  produces bidirectional masks automatically
- Pass is_causal=False from BiEncoderModel.encode() to the model
  forward, which propagates through **kwargs to sdpa/flash attention
  functions — load-bearing for FA2 correctness
- Add extract_submodel parameter to build_encoder_backbone for generic
  VLM text backbone extraction via dotted attribute path
- Set config.is_causal=False and layer.self_attn.is_causal=False on
  models loaded through the generic (non-SUPPORTED_BACKBONES) path

Signed-off-by: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com>
copy-pr-bot bot commented Apr 14, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


…extract_submodel

Move new is_causal tests into the existing test file to match repo
conventions (test files named after what they test, not the change).
Remove extract_submodel from this PR — it will be a separate PR.

Signed-off-by: Oliver Holworthy <oholworthy@nvidia.com>
Signed-off-by: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com>
oliverholworthy and others added 2 commits April 14, 2026 19:42
Move tests for BiEncoderModel.encode() and build_encoder_backbone()
to tests/unit_tests/_transformers/test_retrieval.py to match the
source module location. Keep LlamaBidirectionalModel-specific tests
in test_llama_bidirectional_model.py.

Signed-off-by: Oliver Holworthy <1216955+oliverholworthy@users.noreply.github.com>
akoumpa (Contributor) commented Apr 19, 2026

/ok to test 1cc9623
