Commit 3e99519

Set attn_implementation="eager" for ViT, BERT and BGE-M3
1 parent 24c5cdd commit 3e99519

File tree

3 files changed: +42 -1 lines changed

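For context: attn_implementation is a standard from_pretrained keyword argument in recent transformers releases, and "eager" selects the reference PyTorch attention path instead of SDPA or FlashAttention kernels. The helper added in this commit injects that keyword as a default; a minimal sketch of the equivalent explicit call (the checkpoint name below is illustrative, not the one the benchmark loads):

from transformers import BertModel

# "eager" bypasses SDPA/FlashAttention and uses the plain attention implementation
model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="eager")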

benchmark/tt-xla/encoders.py

Lines changed: 12 additions & 1 deletion

@@ -9,7 +9,7 @@

 from benchmark.utils import aggregate_ttnn_perf_metrics, sanitize_filename
 from encoder_benchmark import benchmark_encoder_torch_xla
-from utils import apply_mean_pooling, apply_last_token_pooling
+from utils import apply_mean_pooling, apply_last_token_pooling, patch_transformers_for_eager_attn


 DTYPE_MAP = {

@@ -156,6 +156,11 @@ def test_encoder(
 def test_bert(output_file):
     from third_party.tt_forge_models.bert.sentence_embedding_generation.pytorch.loader import ModelLoader

+    # TODO(vkovacevic): Issue #804
+    from transformers import BertModel
+
+    patch_transformers_for_eager_attn(BertModel)
+
     # Configuration
     data_format = "bfloat16"
     input_sequence_length = 384

@@ -307,9 +312,15 @@ def test_bge_m3(output_file):
     import torch
     import numpy as np
     from collections import defaultdict
+
     from third_party.tt_forge_models.bge_m3.encode.pytorch.loader import ModelLoader
     from FlagEmbedding import BGEM3FlagModel

+    # TODO(vkovacevic): Issue #804
+    from transformers import XLMRobertaModel
+
+    patch_transformers_for_eager_attn(XLMRobertaModel)
+
     # Configuration
     data_format = "float32"
     input_sequence_length = 512
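Note that patch_transformers_for_eager_attn has to run before the third-party loader constructs the model, and because it uses kwargs.setdefault it only supplies a default: an explicitly passed attn_implementation still wins. A hedged sketch of both points (checkpoint name illustrative; the utils import assumes benchmark/tt-xla is on sys.path, as in the test files above):

from transformers import BertModel
from utils import patch_transformers_for_eager_attn

patch_transformers_for_eager_attn(BertModel)

eager_model = BertModel.from_pretrained("bert-base-uncased")  # default injected -> eager
sdpa_model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa")  # explicit value kept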

benchmark/tt-xla/utils.py

Lines changed: 24 additions & 0 deletions

@@ -377,3 +377,27 @@ def move_to_cpu(data):
         moved = [move_to_cpu(item) for item in data]
         return type(data)(moved)
     return data
+
+
+# TODO(vkovacevic): Issue #804
+def patch_transformers_for_eager_attn(cls):
+    """Monkey patch a transformers model class to use eager attention implementation.
+
+    This patches the from_pretrained method to inject attn_implementation="eager",
+    which disables SDPA and other optimized attention implementations that may not
+    be compatible with certain backends.
+
+    Args:
+        cls: A transformers model class (e.g., ViTForImageClassification, BertModel)
+    """
+    from functools import wraps
+
+    original_from_pretrained = cls.from_pretrained.__func__
+
+    @classmethod
+    @wraps(original_from_pretrained)
+    def patched_from_pretrained(cls, *args, **kwargs):
+        kwargs.setdefault("attn_implementation", "eager")
+        return original_from_pretrained(cls, *args, **kwargs)
+
+    cls.from_pretrained = patched_from_pretrained
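The helper above relies on three details: accessing a classmethod through the class yields a bound method whose .__func__ attribute is the underlying function, wrapping the replacement in @classmethod restores the original calling convention, and kwargs.setdefault injects "eager" without overriding an explicit value. A self-contained sketch of the same mechanism using a stand-in class (FakeModel and patch_for_eager are illustrative names; it runs without transformers installed):

from functools import wraps


class FakeModel:
    @classmethod
    def from_pretrained(cls, name, **kwargs):
        return f"{cls.__name__}({name}, attn={kwargs.get('attn_implementation', 'sdpa')})"


def patch_for_eager(cls):
    # .__func__ unwraps the bound classmethod so it can be re-wrapped cleanly
    original = cls.from_pretrained.__func__

    @classmethod
    @wraps(original)
    def patched(cls, *args, **kwargs):
        kwargs.setdefault("attn_implementation", "eager")  # inject default only
        return original(cls, *args, **kwargs)

    cls.from_pretrained = patched


patch_for_eager(FakeModel)
print(FakeModel.from_pretrained("dummy"))                              # attn=eager
print(FakeModel.from_pretrained("dummy", attn_implementation="sdpa"))  # explicit value wins

Running the sketch prints attn=eager for the first call and attn=sdpa for the second, which is the behaviour the benchmark relies on when the third-party loaders call from_pretrained without the keyword.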

benchmark/tt-xla/vision_models.py

Lines changed: 6 additions & 0 deletions

@@ -6,6 +6,7 @@
 import os

 from benchmark.utils import aggregate_ttnn_perf_metrics, sanitize_filename
+from utils import patch_transformers_for_eager_attn
 from vision_benchmark import benchmark_vision_torch_xla

 # Defaults for all vision models

@@ -229,6 +230,11 @@ def test_unet(output_file):
 def test_vit(output_file):
     from third_party.tt_forge_models.vit.pytorch.loader import ModelLoader, ModelVariant

+    # TODO(vkovacevic): Issue #804
+    from transformers import ViTForImageClassification
+
+    patch_transformers_for_eager_attn(ViTForImageClassification)
+
     variant = ModelVariant.BASE
     read_logits_fn = lambda output: output.logits
     test_vision(
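One way to sanity-check the patch after loading, assuming a recent transformers release that records the resolved implementation on the internal config._attn_implementation attribute (the checkpoint name is illustrative, not necessarily the one ModelLoader uses):

from transformers import ViTForImageClassification
from utils import patch_transformers_for_eager_attn

patch_transformers_for_eager_attn(ViTForImageClassification)
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

# _attn_implementation is an internal attribute of recent transformers configs
assert model.config._attn_implementation == "eager"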
