3 files changed: +42 -1

File 1 of 3 (encoder benchmark tests):

 from benchmark.utils import aggregate_ttnn_perf_metrics, sanitize_filename
 from encoder_benchmark import benchmark_encoder_torch_xla
-from utils import apply_mean_pooling, apply_last_token_pooling
+from utils import apply_mean_pooling, apply_last_token_pooling, patch_transformers_for_eager_attn


 DTYPE_MAP = {
@@ -156,6 +156,11 @@ def test_encoder(
 def test_bert(output_file):
     from third_party.tt_forge_models.bert.sentence_embedding_generation.pytorch.loader import ModelLoader

+    # TODO(vkovacevic): Issue #804
+    from transformers import BertModel
+
+    patch_transformers_for_eager_attn(BertModel)
+
     # Configuration
     data_format = "bfloat16"
     input_sequence_length = 384
@@ -307,9 +312,15 @@ def test_bge_m3(output_file):
     import torch
     import numpy as np
     from collections import defaultdict
+
     from third_party.tt_forge_models.bge_m3.encode.pytorch.loader import ModelLoader
     from FlagEmbedding import BGEM3FlagModel

+    # TODO(vkovacevic): Issue #804
+    from transformers import XLMRobertaModel
+
+    patch_transformers_for_eager_attn(XLMRobertaModel)
+
     # Configuration
     data_format = "float32"
     input_sequence_length = 512
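
Note (not part of the diff): once the patch has been applied to a class, every later from_pretrained call on it defaults to eager attention, which is what these tests rely on when their loaders build the model. A minimal sketch, assuming a standard checkpoint name that is not taken from the diff:

from transformers import BertModel
from utils import patch_transformers_for_eager_attn

patch_transformers_for_eager_attn(BertModel)

# Behaves like BertModel.from_pretrained(..., attn_implementation="eager").
model = BertModel.from_pretrained("bert-base-uncased")  # checkpoint name assumed for illustration

# Because the wrapper uses kwargs.setdefault, an explicit choice still takes precedence
# (assuming the installed transformers version supports SDPA for this model).
model_sdpa = BertModel.from_pretrained("bert-base-uncased", attn_implementation="sdpa")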

File 2 of 3 (shared utils, defines the patch helper):

@@ -377,3 +377,27 @@ def move_to_cpu(data):
         moved = [move_to_cpu(item) for item in data]
         return type(data)(moved)
     return data
+
+
+# TODO(vkovacevic): Issue #804
+def patch_transformers_for_eager_attn(cls):
+    """Monkey patch a transformers model class to use eager attention implementation.
+
+    This patches the from_pretrained method to inject attn_implementation="eager",
+    which disables SDPA and other optimized attention implementations that may not
+    be compatible with certain backends.
+
+    Args:
+        cls: A transformers model class (e.g., ViTForImageClassification, BertModel)
+    """
+    from functools import wraps
+
+    original_from_pretrained = cls.from_pretrained.__func__
+
+    @classmethod
+    @wraps(original_from_pretrained)
+    def patched_from_pretrained(cls, *args, **kwargs):
+        kwargs.setdefault("attn_implementation", "eager")
+        return original_from_pretrained(cls, *args, **kwargs)
+
+    cls.from_pretrained = patched_from_pretrained
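
Note (not part of the diff): the helper goes through cls.from_pretrained.__func__ because from_pretrained is a classmethod; the raw function is unbound, wrapped, and re-installed as a classmethod. A minimal, self-contained sketch of that same pattern on a toy class (all names here are illustrative, not from the repo):

from functools import wraps

class Example:
    @classmethod
    def build(cls, **kwargs):
        return kwargs

plain_func = Example.build.__func__          # unbind the classmethod to get the raw function

@classmethod
@wraps(plain_func)
def patched(cls, **kwargs):
    kwargs.setdefault("mode", "eager")       # same setdefault trick as the helper
    return plain_func(cls, **kwargs)

Example.build = patched

assert Example.build() == {"mode": "eager"}
assert Example.build(mode="sdpa") == {"mode": "sdpa"}   # an explicit value still wins

The kwargs.setdefault line is why a caller that passes its own attn_implementation keeps that choice even after the class has been patched.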

File 3 of 3 (vision benchmark tests):

 import os

 from benchmark.utils import aggregate_ttnn_perf_metrics, sanitize_filename
+from utils import patch_transformers_for_eager_attn
 from vision_benchmark import benchmark_vision_torch_xla

 # Defaults for all vision models
@@ -229,6 +230,11 @@ def test_unet(output_file):
 def test_vit(output_file):
     from third_party.tt_forge_models.vit.pytorch.loader import ModelLoader, ModelVariant

+    # TODO(vkovacevic): Issue #804
+    from transformers import ViTForImageClassification
+
+    patch_transformers_for_eager_attn(ViTForImageClassification)
+
     variant = ModelVariant.BASE
     read_logits_fn = lambda output: output.logits
     test_vision(
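
Note (not part of the diff): the patch only affects loads that happen after it runs, so in these tests it is applied before the loader constructs the model. A quick sanity check of the effect; the checkpoint name and the _attn_implementation config attribute are assumptions about recent transformers versions, not taken from the diff:

from transformers import ViTForImageClassification
from utils import patch_transformers_for_eager_attn

patch_transformers_for_eager_attn(ViTForImageClassification)

# The loader's internal from_pretrained call (or a direct one, as here) now defaults to eager.
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")  # checkpoint assumed

# Recent transformers versions record the chosen backend on the config; attribute name assumed.
print(getattr(model.config, "_attn_implementation", "<not recorded>"))  # expected: "eager"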