1515"""TensorRT LLM perf sanity tests."""
1616
1717import copy
18+ import fcntl
1819import glob
1920import os
2021import re
22+ import shutil
2123import socket
2224import subprocess
2325import time
6062 "H200" : "h200" ,
6163}
6264
65+ BENCH_SERVING_REPO = "https://github.com/kedarpotdar-nv/bench_serving.git"
66+ BENCH_SERVING_COMMIT = "f3ea022a5780de5d0babc5fffa53634e2023d28f"
67+ BENCH_SERVING_DIR = "/tmp/bench_serving"
68+
69+
def ensure_bench_serving_repo() -> str:
    """Clone the bench_serving repo pinned at BENCH_SERVING_COMMIT if needed.

    Returns:
        Absolute path to ``benchmark_serving.py`` inside the checkout.

    Uses an exclusive ``fcntl`` file lock so multiple ranks inside the same
    container do not race each other while cloning. The existing checkout is
    validated against BENCH_SERVING_COMMIT, so a stale clone left over from a
    previous pin (or an interrupted setup) is replaced instead of reused.
    """
    bench_script = os.path.join(BENCH_SERVING_DIR, "benchmark_serving.py")
    lock_file = BENCH_SERVING_DIR + ".lock"

    with open(lock_file, "w") as lf:
        fcntl.flock(lf, fcntl.LOCK_EX)
        try:
            if not _bench_serving_checkout_is_valid(bench_script):
                if os.path.exists(BENCH_SERVING_DIR):
                    # Remove partial or stale checkouts before re-cloning.
                    shutil.rmtree(BENCH_SERVING_DIR)
                subprocess.check_call(
                    ["git", "clone", "--depth", "1", BENCH_SERVING_REPO, BENCH_SERVING_DIR]
                )
                # A shallow clone of the default branch may not contain the
                # pinned commit; fetch that SHA explicitly (GitHub permits
                # fetching arbitrary reachable SHAs) and check it out detached.
                subprocess.check_call(
                    [
                        "git",
                        "-C",
                        BENCH_SERVING_DIR,
                        "fetch",
                        "--depth",
                        "1",
                        "origin",
                        BENCH_SERVING_COMMIT,
                    ]
                )
                subprocess.check_call(
                    ["git", "-C", BENCH_SERVING_DIR, "checkout", BENCH_SERVING_COMMIT]
                )
        finally:
            fcntl.flock(lf, fcntl.LOCK_UN)

    return bench_script


def _bench_serving_checkout_is_valid(bench_script: str) -> bool:
    """Return True iff the benchmark script exists and HEAD is the pinned commit."""
    if not os.path.exists(bench_script):
        return False
    try:
        head = subprocess.check_output(
            ["git", "-C", BENCH_SERVING_DIR, "rev-parse", "HEAD"],
            text=True,
        ).strip()
    except (subprocess.CalledProcessError, OSError):
        # Not a usable git checkout (corrupt dir, git missing, etc.) — re-clone.
        return False
    return head == BENCH_SERVING_COMMIT
107+
108+
63109DEFAULT_TIMEOUT = 5400
64110AGG_CONFIG_FOLDER = os .environ .get ("AGG_CONFIG_FOLDER" , "tests/scripts/perf-sanity/aggregated" )
65111DISAGG_CONFIG_FOLDER = os .environ .get (
@@ -439,6 +485,7 @@ def __init__(
439485 self .trust_remote_code = client_config_data .get ("trust_remote_code" , True )
440486 self .model_path = ""
441487 self .dataset_file = client_config_data .get ("dataset_file" , "" )
488+ self .use_nv_sa_benchmark = client_config_data .get ("use_nv_sa_benchmark" , False )
442489 self .env_vars = env_vars
443490
444491 # Generate default name if not provided
@@ -450,6 +497,48 @@ def to_cmd(self) -> List[str]:
450497 """Generate benchmark command."""
451498 model_dir = get_model_dir (self .model_name )
452499 self .model_path = model_dir if os .path .exists (model_dir ) else self .model_name
500+
501+ if self .use_nv_sa_benchmark :
502+ return self ._to_sa_benchmark_cmd ()
503+ else :
504+ return self ._to_default_benchmark_cmd ()
505+
506+ def _to_sa_benchmark_cmd (self ) -> List [str ]:
507+ """Generate SA benchmark command (bench_serving repo)."""
508+ bench_script = ensure_bench_serving_repo ()
509+ benchmark_cmd = [
510+ "python" ,
511+ bench_script ,
512+ "--model" ,
513+ self .model_path ,
514+ "--dataset-name" ,
515+ "random" ,
516+ "--num-prompts" ,
517+ str (self .concurrency * self .iterations ),
518+ "--max-concurrency" ,
519+ str (self .concurrency ),
520+ "--ignore-eos" ,
521+ "--random-input-len" ,
522+ str (self .isl ),
523+ "--random-output-len" ,
524+ str (self .osl ),
525+ "--random-range-ratio" ,
526+ str (self .random_range_ratio ),
527+ "--save-result" ,
528+ "--percentile-metrics" ,
529+ "ttft,tpot,itl,e2el" ,
530+ ]
531+ if self .backend :
532+ benchmark_cmd .extend (["--backend" , self .backend ])
533+ if self .trust_remote_code :
534+ benchmark_cmd .append ("--trust-remote-code" )
535+ if self .use_chat_template :
536+ benchmark_cmd .append ("--use-chat-template" )
537+ # Note: bench_serving has no --non-streaming flag; streaming is backend-determined
538+ return benchmark_cmd
539+
540+ def _to_default_benchmark_cmd (self ) -> List [str ]:
541+ """Generate default benchmark command (tensorrt_llm benchmark_serving)."""
453542 dataset_path = get_dataset_dir (self .dataset_file )
454543 benchmark_cmd = [
455544 "python" ,
@@ -513,6 +602,7 @@ def to_match_keys(self) -> List[str]:
513602 "s_backend" ,
514603 "b_use_chat_template" ,
515604 "b_streaming" ,
605+ "b_use_nv_sa_benchmark" ,
516606 ]
517607
518608 def to_db_data (self ) -> dict :
@@ -529,6 +619,7 @@ def to_db_data(self) -> dict:
529619 "b_use_chat_template" : self .use_chat_template ,
530620 "b_streaming" : self .streaming ,
531621 "b_trust_remote_code" : self .trust_remote_code ,
622+ "b_use_nv_sa_benchmark" : self .use_nv_sa_benchmark ,
532623 "s_client_log_link" : "" ,
533624 "s_client_env_vars" : self .env_vars ,
534625 }
@@ -1292,6 +1383,7 @@ def _parse_disagg_config_file(self, config_file_path: str, config_file: str):
12921383 # For ctx_only: OSL is set to 1 and dataset_file is empty
12931384 osl = 1 if benchmark_mode == "ctx_only" else benchmark .get ("output_length" , 1024 )
12941385 dataset_file = "" if benchmark_mode == "ctx_only" else benchmark .get ("dataset_file" , "" )
1386+ use_nv_sa_benchmark = benchmark .get ("use_nv_sa_benchmark" , False )
12951387
12961388 client_configs = []
12971389 for concurrency in concurrency_values :
@@ -1305,6 +1397,7 @@ def _parse_disagg_config_file(self, config_file_path: str, config_file: str):
13051397 "use_chat_template" : False ,
13061398 "streaming" : benchmark .get ("streaming" , True ),
13071399 "dataset_file" : dataset_file ,
1400+ "use_nv_sa_benchmark" : use_nv_sa_benchmark ,
13081401 }
13091402 client_config = ClientConfig (
13101403 client_config_data ,
@@ -1426,19 +1519,36 @@ def _check_benchmark_errors(self, output: str) -> None:
14261519 if not output :
14271520 return
14281521
1429- # Check for non-zero failed requests
1522+ # Check for non-zero failed requests (default benchmark)
14301523 failed_requests_match = re .search (r"Failed requests:\s+(\d+)" , output )
14311524 if failed_requests_match :
14321525 failed_count = int (failed_requests_match .group (1 ))
14331526 if failed_count > 0 :
14341527 error_msg = f"Benchmark output contains { failed_count } failed requests."
14351528 raise RuntimeError (error_msg )
14361529
1437- # Check for explicit failure markers
1530+ # Check for explicit failure markers (default benchmark)
14381531 if "!FAILED REQUESTS!" in output or "!CHECK LOG FOR ERRORS!" in output :
14391532 error_msg = "Benchmark output contains failure markers."
14401533 raise RuntimeError (error_msg )
14411534
1535+ # SA benchmark (bench_serving) only prints "Successful requests:"
1536+ # without "Failed requests:". Detect failures by comparing successful
1537+ # count against num_prompts from the Namespace output.
1538+ if not failed_requests_match :
1539+ successful_match = re .search (r"Successful requests:\s+(\d+)" , output )
1540+ num_prompts_match = re .search (r"num_prompts=(\d+)" , output )
1541+ if successful_match and num_prompts_match :
1542+ successful_count = int (successful_match .group (1 ))
1543+ num_prompts = int (num_prompts_match .group (1 ))
1544+ failed_count = num_prompts - successful_count
1545+ if failed_count > 0 :
1546+ error_msg = (
1547+ f"SA benchmark: { failed_count } of { num_prompts } requests failed "
1548+ f"({ successful_count } successful)."
1549+ )
1550+ raise RuntimeError (error_msg )
1551+
14421552 def run_ex (self , commands ) -> Dict [int , List [str ]]:
14431553 """Run commands and collect outputs."""
14441554 outputs = {}
@@ -1478,8 +1588,17 @@ def parse_metrics_from_output(output: str) -> Optional[Dict[str, float]]:
14781588 for server_idx , client_configs in self .server_client_configs .items ():
14791589 self ._perf_results [server_idx ] = []
14801590 server_outputs = outputs .get (server_idx , [])
1481- for output in server_outputs :
1591+ for client_idx , output in enumerate ( server_outputs ) :
14821592 metrics = parse_metrics_from_output (output )
1593+ # SA benchmark (bench_serving) doesn't report user_throughput.
1594+ # Use None as sentinel to distinguish "not available" from actual zero.
1595+ if (
1596+ metrics
1597+ and "user_throughput" not in metrics
1598+ and client_idx < len (client_configs )
1599+ and client_configs [client_idx ].use_nv_sa_benchmark
1600+ ):
1601+ metrics ["user_throughput" ] = None
14831602 self ._perf_results [server_idx ].append (metrics )
14841603
14851604 def check_test_failure (self ):
0 commit comments