benchmark/tt-xla/conftest.py (11 additions, 0 deletions)

@@ -162,6 +162,12 @@ def pytest_addoption(parser):
type=make_validator_boolean("--experimental-compile"),
help="Enable experimental compile flag (true/false). Overrides config value.",
)
parser.addoption(
"--profile",
action="store_true",
default=False,
help="Enable profiling mode: uses single layer, minimal iterations, and tracy signposts.",
)


@pytest.fixture
@@ -217,3 +223,8 @@ def task(request):
@pytest.fixture
def experimental_compile(request):
return request.config.getoption("--experimental-compile")


@pytest.fixture
def profile(request):
return request.config.getoption("--profile")
benchmark/tt-xla/llm_benchmark.py (12 additions, 1 deletion)

@@ -24,6 +24,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from transformers.cache_utils import StaticCache
from transformers.modeling_outputs import CausalLMOutputWithPast
import tracy

from benchmark.utils import get_xla_device_arch
from utils import (
@@ -193,6 +194,7 @@ def generate_and_benchmark(
iteration_times: List[float] = []
with torch.no_grad():
for step in range(max_tokens_to_generate):
tracy.signpost("token_generation_start")
start = time.perf_counter_ns()

# Run forward pass
@@ -222,6 +224,8 @@
input_args["cache_position"] = host_cache_pos.to(device)

end = time.perf_counter_ns()
tracy.signpost("token_generation_end")

iteration_times.append(end - start)
if verbose:
print(f"Iteration\t{step}/{max_tokens_to_generate}\ttook {iteration_times[-1] / 1e6:.04} ms")
@@ -268,6 +272,7 @@ def benchmark_llm_torch_xla(
shard_spec_fn,
arch,
required_pcc,
profile=False,
):
"""
Benchmark an LLM (Large Language Model) using PyTorch and torch-xla.
@@ -352,6 +357,10 @@ def benchmark_llm_torch_xla(
# Limit maximum generation count to fit within preallocated static cache
max_tokens_to_generate: int = max_cache_len - input_args["input_ids"].shape[1]

# In profile mode, limit tokens to 2 for faster profiling
if profile:
max_tokens_to_generate = 2

# Get CPU result
cpu_logits, _ = generate_and_benchmark(
model,
@@ -423,6 +432,8 @@ def benchmark_llm_torch_xla(
mesh=mesh,
)

tracy.signpost("warmup_complete")

# Reconstruct inputs for the actual benchmark run
input_args = construct_inputs(
tokenizer, model.config, batch_size, max_cache_len, past_key_values=input_args["past_key_values"]
@@ -443,7 +454,7 @@
mesh=mesh,
)

if len(iteration_times) < 10:
if not profile and len(iteration_times) < 10:
raise RuntimeError("LLM benchmark failed: insufficient number of iterations completed.")

ttft_ns = iteration_times[0]
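Downstream of this diff, the timing list feeds the reported metrics: each entry of `iteration_times` is a `perf_counter_ns` delta for one generated token, the first entry is the time-to-first-token (`ttft_ns = iteration_times[0]`), and the remaining entries are steady-state decode steps. A hedged sketch of that derivation follows; the helper name and returned fields are assumptions, not code from this PR.

```python
# Sketch (assumed, not from this PR) of summarizing the per-token timings
# collected by generate_and_benchmark. Times are in nanoseconds.
from typing import Dict, List


def summarize_iteration_times(iteration_times: List[float]) -> Dict[str, float]:
    ttft_ms = iteration_times[0] / 1e6
    # Guard for the short --profile run, which generates only 2 tokens.
    decode_times = iteration_times[1:] or iteration_times
    avg_decode_ms = sum(decode_times) / len(decode_times) / 1e6
    return {
        "ttft_ms": ttft_ms,
        "avg_decode_ms": avg_decode_ms,
        "decode_tokens_per_s": 1e3 / avg_decode_ms,
    }
```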