1111
1212from models .experimental .functional_bloom .tt import ttnn_functional_bloom
1313from models .experimental .functional_bloom .tt import ttnn_optimized_functional_bloom
14- from models .utility_functions import enable_persistent_kernel_cache , disable_persistent_kernel_cache
1514from models .utility_functions import skip_for_wormhole_b0
1615
1716import ttnn
@@ -87,8 +86,6 @@ def test_performance_of_bloom_for_question_answering(
8786):
8887 torch .manual_seed (0 )
8988
90- enable_persistent_kernel_cache ()
91-
9289 model_name = "bigscience/bloom-560m"
9390 config = BloomConfig .from_pretrained (model_name )
9491 tokenizer = BloomTokenizerFast .from_pretrained (model_name )
@@ -117,6 +114,10 @@ def test_performance_of_bloom_for_question_answering(
117114 input_ids = input_ids , device = device , num_heads = num_heads , attention_mask = attention_mask , max_length = max_length
118115 )
119116
117+ # TODO: don't modify the config globally. Pass it into the functions instead
118+ ttnn_optimized_functional_bloom .BLOOM_MEMORY_CONFIG = ttnn .L1_MEMORY_CONFIG
119+ ttnn_optimized_functional_bloom .ASSUME_FUSED_SOFTMAX = True
120+
120121 # Run twice to measure the time with and without the program cache
121122 for _ in range (2 ):
122123 start = time .time ()
@@ -129,4 +130,6 @@ def test_performance_of_bloom_for_question_answering(
129130 logger .info (f"Duration: { duration } " )
130131 logger .info (f"Samples per second: { 1 / duration * batch_size } " )
131132
132- disable_persistent_kernel_cache ()
133+ # TODO: don't modify the config globally. Pass it into the functions instead
134+ ttnn_optimized_functional_bloom .BLOOM_MEMORY_CONFIG = ttnn .DRAM_MEMORY_CONFIG
135+ ttnn_optimized_functional_bloom .ASSUME_FUSED_SOFTMAX = False
0 commit comments