2424from transformers import AutoModelForCausalLM , AutoTokenizer , PreTrainedTokenizer
2525from transformers .cache_utils import StaticCache
2626from transformers .modeling_outputs import CausalLMOutputWithPast
27+ import tracy
2728
2829from benchmark .utils import get_xla_device_arch
2930from utils import (
@@ -193,6 +194,7 @@ def generate_and_benchmark(
193194 iteration_times : List [float ] = []
194195 with torch .no_grad ():
195196 for step in range (max_tokens_to_generate ):
197+ tracy .signpost ("token_generation_start" )
196198 start = time .perf_counter_ns ()
197199
198200 # Run forward pass
@@ -222,6 +224,8 @@ def generate_and_benchmark(
222224 input_args ["cache_position" ] = host_cache_pos .to (device )
223225
224226 end = time .perf_counter_ns ()
227+ tracy .signpost ("token_generation_end" )
228+
225229 iteration_times .append (end - start )
226230 if verbose :
227231 print (f"Iteration\t { step } /{ max_tokens_to_generate } \t took { iteration_times [- 1 ] / 1e6 :.04} ms" )
@@ -268,6 +272,7 @@ def benchmark_llm_torch_xla(
268272 shard_spec_fn ,
269273 arch ,
270274 required_pcc ,
275+ profile = False ,
271276):
272277 """
273278 Benchmark an LLM (Large Language Model) using PyTorch and torch-xla.
@@ -352,6 +357,10 @@ def benchmark_llm_torch_xla(
352357 # Limit maximum generation count to fit within preallocated static cache
353358 max_tokens_to_generate : int = max_cache_len - input_args ["input_ids" ].shape [1 ]
354359
360+ # In profile mode, limit tokens to 2 for faster profiling
361+ if profile :
362+ max_tokens_to_generate = 2
363+
355364 # Get CPU result
356365 cpu_logits , _ = generate_and_benchmark (
357366 model ,
@@ -423,6 +432,8 @@ def benchmark_llm_torch_xla(
423432 mesh = mesh ,
424433 )
425434
435+ tracy .signpost ("warmup_complete" )
436+
426437 # Reconstruct inputs for the actual benchmark run
427438 input_args = construct_inputs (
428439 tokenizer , model .config , batch_size , max_cache_len , past_key_values = input_args ["past_key_values" ]
@@ -443,7 +454,7 @@ def benchmark_llm_torch_xla(
443454 mesh = mesh ,
444455 )
445456
446- if len (iteration_times ) < 10 :
457+ if not profile and len (iteration_times ) < 10 :
447458 raise RuntimeError ("LLM benchmark failed: insufficient number of iterations completed." )
448459
449460 ttft_ns = iteration_times [0 ]
0 commit comments