Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions nemo_automodel/_transformers/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,20 @@ def build_encoder_backbone(
# Fallback: use HuggingFace Auto classes for model types not in SUPPORTED_BACKBONES
logger.info(f"Model type '{model_type}' not in SUPPORTED_BACKBONES; falling back to HuggingFace Auto classes")
if task == "score":
return AutoModelForSequenceClassification.from_pretrained(
model = AutoModelForSequenceClassification.from_pretrained(
model_name_or_path, trust_remote_code=trust_remote_code, **hf_kwargs
)
return AutoModel.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **hf_kwargs)
else:
model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **hf_kwargs)

# Make the backbone bidirectional: config flag for mask generation,
# module flag for SDPA/FA2 kernel fallback.
model.config.is_causal = False
for layer in getattr(model, "layers", []):
if hasattr(layer, "self_attn"):
layer.self_attn.is_causal = False

return model


def save_encoder_pretrained(model: nn.Module, save_directory: str, **kwargs) -> None:
Expand Down Expand Up @@ -268,6 +278,7 @@ def encode(self, input_dict: dict) -> Optional[torch.Tensor]:

outputs = self.model(
**{k: v for k, v in input_dict.items() if k not in ["kd_labels"]},
is_causal=False,
return_dict=True,
output_hidden_states=True,
)
Expand Down
76 changes: 5 additions & 71 deletions nemo_automodel/components/models/llama_bidirectional/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,10 @@
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache, DynamicCache
from transformers.masking_utils import create_bidirectional_mask
from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
from transformers.cache_utils import Cache
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs

try:
from nemo_automodel.shared.import_utils import get_check_model_inputs_decorator

check_model_inputs = get_check_model_inputs_decorator()
except ImportError:
# Fallback to no-op decorator if import fails
def check_model_inputs(func):
return func


class LlamaBidirectionalConfig(LlamaConfig):
Expand Down Expand Up @@ -96,66 +84,12 @@ def __init__(self, config: LlamaConfig):
config: Model configuration
"""
super().__init__(config)
# Disable causal attention for all layers
# Enable bidirectional attention: config flag for mask generation,
# module flag for SDPA/FA2 kernel fallback.
config.is_causal = False
for layer in self.layers:
layer.self_attn.is_causal = False

@check_model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
**kwargs: Unpack[TransformersKwargs],
) -> BaseModelOutputWithPast:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

if inputs_embeds is None:
inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position: torch.Tensor = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

bidirectional_mask = create_bidirectional_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
)

hidden_states = inputs_embeds
position_embeddings = self.rotary_emb(hidden_states, position_ids)

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
hidden_states = decoder_layer(
hidden_states,
attention_mask=bidirectional_mask,
position_ids=position_ids,
past_key_values=past_key_values,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)

hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)


def _pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor, pool_type: str) -> torch.Tensor:
"""Pool hidden states using the specified pooling method."""
Expand Down
122 changes: 122 additions & 0 deletions tests/unit_tests/_transformers/test_retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for nemo_automodel._transformers.retrieval (build_encoder_backbone, BiEncoderModel, etc.)."""

from unittest.mock import MagicMock

import torch
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutputWithPast

from nemo_automodel._transformers.retrieval import BiEncoderModel

# ---------------------------------------------------------------------------
# BiEncoderModel.encode() passes is_causal=False
# ---------------------------------------------------------------------------


class _SpyModel(nn.Module):
"""A fake model that records the kwargs passed to forward()."""

def __init__(self, hidden_size=16):
super().__init__()
self.config = MagicMock(hidden_size=hidden_size)
self.captured_kwargs = {}

def forward(self, input_ids=None, attention_mask=None, **kwargs):
self.captured_kwargs = dict(kwargs)
bsz, seq = input_ids.shape
h = self.config.hidden_size
last = torch.ones(bsz, seq, h)
return BaseModelOutputWithPast(
last_hidden_state=last,
hidden_states=[last],
)


def test_encode_passes_is_causal_false():
    """encode() must forward is_causal=False to the wrapped model so
    FA2/SDPA kernels do not fall back to causal masking."""
    spy = _SpyModel(hidden_size=16)
    wrapper = BiEncoderModel(model=spy, pooling="avg", l2_normalize=False)

    batch = {
        "input_ids": torch.ones(2, 4, dtype=torch.long),
        "attention_mask": torch.ones(2, 4, dtype=torch.long),
    }
    wrapper.encode(batch)

    captured = spy.captured_kwargs
    assert "is_causal" in captured, "encode() must pass is_causal kwarg to model forward"
    assert captured["is_causal"] is False


# ---------------------------------------------------------------------------
# build_encoder_backbone sets is_causal flags on generic models
# ---------------------------------------------------------------------------


class _FakeAttention(nn.Module):
def __init__(self):
super().__init__()
self.is_causal = True


class _FakeDecoderLayer(nn.Module):
    """Decoder-layer stub exposing a single `self_attn` submodule."""

    def __init__(self):
        super().__init__()
        self.self_attn = _FakeAttention()


class _FakeDecoderModel(nn.Module):
    """Minimal decoder-style model: a config carrying model_type plus two layers."""

    def __init__(self):
        super().__init__()
        # build_encoder_backbone inspects config.model_type to pick a code path.
        self.config = type("Cfg", (), {"model_type": "fake_decoder"})()
        self.layers = nn.ModuleList(_FakeDecoderLayer() for _ in range(2))


def _mock_generic_automodel(monkeypatch):
    """Patch retrieval's Auto classes so loading returns a fake decoder model.

    Returns the fake model instance so tests can assert against it directly.
    """
    from nemo_automodel._transformers import retrieval

    fake_model = _FakeDecoderModel()

    def _load_model(*args, **kwargs):
        return fake_model

    def _load_config(*args, **kwargs):
        return type("Cfg", (), {"model_type": "fake_decoder"})()

    monkeypatch.setattr(
        retrieval,
        "AutoModel",
        type("AutoModel", (), {"from_pretrained": staticmethod(_load_model)}),
    )
    monkeypatch.setattr(
        retrieval,
        "AutoConfig",
        type("AutoConfig", (), {"from_pretrained": staticmethod(_load_config)}),
    )
    return fake_model


def test_build_encoder_backbone_sets_config_is_causal(monkeypatch):
    """Generic (Auto-class) backbones must come back with config.is_causal = False."""
    _mock_generic_automodel(monkeypatch)
    from nemo_automodel._transformers import retrieval

    backbone = retrieval.build_encoder_backbone("fake/path", task="embedding")
    assert backbone.config.is_causal is False


def test_build_encoder_backbone_sets_attention_is_causal(monkeypatch):
    """Every attention module of a generic backbone must end up with is_causal = False."""
    _mock_generic_automodel(monkeypatch)
    from nemo_automodel._transformers import retrieval

    backbone = retrieval.build_encoder_backbone("fake/path", task="embedding")
    assert all(layer.self_attn.is_causal is False for layer in backbone.layers)
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,23 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_outputs import SequenceClassifierOutputWithPast
from transformers.modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
from transformers.models.llama.modeling_llama import LlamaModel

from nemo_automodel._transformers.registry import ModelRegistry
from nemo_automodel._transformers.retrieval import (
BiEncoderModel,
CrossEncoderModel,
configure_encoder_metadata,
_init_encoder_common,
configure_encoder_metadata,
pool,
)
from nemo_automodel.recipes.retrieval.train_bi_encoder import contrastive_scores_and_labels
from nemo_automodel._transformers.registry import ModelRegistry
from nemo_automodel.components.models.llama_bidirectional.model import (
LlamaBidirectionalConfig,
LlamaBidirectionalForSequenceClassification,
LlamaBidirectionalModel,
)
from transformers.modeling_outputs import BaseModelOutputWithPast
from nemo_automodel.recipes.retrieval.train_bi_encoder import contrastive_scores_and_labels


def test_contrastive_scores_and_labels_shapes_and_labels():
Expand Down Expand Up @@ -548,3 +548,67 @@ def __init__(self):
_init_encoder_common(encoder, fake)

assert encoder.name_or_path == "Qwen/Qwen3-1.7B"


# ---------------------------------------------------------------------------
# is_causal=False refactor: config, forward delegation, encode kwarg,
# extract_submodel, and generic-path is_causal flags
# ---------------------------------------------------------------------------


def _tiny_bidirec_config(**overrides):
    """Build a tiny LlamaBidirectionalConfig; keyword overrides win over defaults."""
    params = {
        "vocab_size": 32,
        "hidden_size": 16,
        "num_hidden_layers": 1,
        "num_attention_heads": 1,
        "num_key_value_heads": 1,
        "intermediate_size": 32,
        "pad_token_id": 0,
    }
    params.update(overrides)
    return LlamaBidirectionalConfig(**params)


def test_config_is_causal_set_to_false():
    """After init, config.is_causal must be False — required for create_causal_mask
    to redirect to create_bidirectional_mask in the parent forward()."""
    model = LlamaBidirectionalModel(_tiny_bidirec_config())
    assert model.config.is_causal is False


def test_no_forward_override():
    """LlamaBidirectionalModel must inherit forward() unchanged from LlamaModel,
    whose mask construction honors config.is_causal = False."""
    assert "forward" not in vars(LlamaBidirectionalModel)
    assert LlamaBidirectionalModel.forward is LlamaModel.forward


def test_bidirectional_output_via_parent_forward():
    """With bidirectional attention, editing a later token must perturb the
    hidden state of an earlier position; causal attention could not do that."""
    model = LlamaBidirectionalModel(_tiny_bidirec_config())
    model.eval()

    base_ids = torch.tensor([[1, 2, 3, 4]])
    edited_ids = torch.tensor([[1, 2, 3, 5]])  # differs only at the final token
    attn = torch.ones_like(base_ids)

    with torch.no_grad():
        base_out = model(input_ids=base_ids, attention_mask=attn)
        edited_out = model(input_ids=edited_ids, attention_mask=attn)

    # Under bidirectional attention every position attends to the changed
    # token, so even position 0 must move; under causal masking it would not.
    first_base = base_out.last_hidden_state[0, 0]
    first_edited = edited_out.last_hidden_state[0, 0]
    assert not torch.allclose(first_base, first_edited, atol=1e-5), (
        "Position 0 hidden state should differ when a later token changes "
        "(bidirectional attention). If identical, attention is still causal."
    )


Loading