Skip to content

Commit 6344b6a

Browse files
committed
Add attn_implementation="eager" to ViT and BERT
1 parent d245d40 commit 6344b6a

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

benchmark/tt-xla/test_encoders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def test_bert(output_file):
158158
loader = ModelLoader()
159159
model_info_name = loader.get_model_info().name
160160
print(f"\nLoading model {model_info_name}...")
161-
model = loader.load_model(dtype_override=DTYPE_MAP[data_format])
161+
model = loader.load_model(dtype_override=DTYPE_MAP[data_format], attn_implementation="eager")
162162

163163
# Create function for loading raw inputs
164164
load_inputs_fn = get_default_inputs

benchmark/tt-xla/test_vision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ def test_vit(output_file):
402402
variant = ModelVariant.BASE
403403
loader = ModelLoader(variant=variant)
404404
model_info_name = loader.get_model_info(variant=variant).name
405-
model = loader.load_model(dtype_override=data_format)
405+
model = loader.load_model(dtype_override=data_format, attn_implementation="eager")
406406
model = model.eval()
407407

408408
def load_inputs_fn(batch_size, dtype):

third_party/tt_forge_models

Submodule tt_forge_models updated 263 files

0 commit comments

Comments
 (0)