diff --git a/src/inference_endpoint/cli.py b/src/inference_endpoint/cli.py index f76c7050..e8cbef7e 100644 --- a/src/inference_endpoint/cli.py +++ b/src/inference_endpoint/cli.py @@ -267,8 +267,10 @@ def _add_online_specific_args(parser): ) parser.add_argument( "--concurrency", - type=int, - help="Max concurrent requests (required when --load-pattern=concurrency)", + type=str, + help="Max concurrent requests (required when --load-pattern=concurrency). " + "Can be a single value (e.g., '10') or comma-separated list (e.g., '10,20,30') " + "to run multiple benchmarks sequentially.", ) diff --git a/src/inference_endpoint/commands/benchmark.py b/src/inference_endpoint/commands/benchmark.py index 69b8a20b..b5d3bd1f 100644 --- a/src/inference_endpoint/commands/benchmark.py +++ b/src/inference_endpoint/commands/benchmark.py @@ -273,6 +273,28 @@ def _build_config_from_cli( Raises: InputValidationError: If required params missing """ + # Parse concurrency argument (can be single value or comma-separated list) + concurrency_value: int | list[int] | None = None + if concurrency_str := getattr(args, "concurrency", None): + if isinstance(concurrency_str, int): + concurrency_value = concurrency_str + elif "," in concurrency_str: + # Parse comma-separated list + try: + concurrency_value = [int(c.strip()) for c in concurrency_str.split(",")] + except ValueError as e: + raise InputValidationError( + f"Invalid concurrency value '{concurrency_str}': all values must be integers" + ) from e + else: + # Parse single integer + try: + concurrency_value = int(concurrency_str) + except ValueError as e: + raise InputValidationError( + f"Invalid concurrency value '{concurrency_str}': must be an integer" + ) from e + # Determine load pattern (CLI override or mode default) if load_pattern_arg := getattr(args, "load_pattern", None): load_pattern_type = LoadPatternType(load_pattern_arg) @@ -280,7 +302,7 @@ def _build_config_from_cli( match benchmark_mode: case "offline": load_pattern_type = LoadPatternType.MAX_THROUGHPUT - case "online" if getattr(args, "concurrency", None): + case "online" if concurrency_value is not None: load_pattern_type = LoadPatternType.CONCURRENCY case "online": load_pattern_type = LoadPatternType.POISSON @@ -309,7 +331,7 @@ def _build_config_from_cli( load_pattern=LoadPattern( type=load_pattern_type, target_qps=getattr(args, "target_qps", None), - target_concurrency=getattr(args, "concurrency", None), + target_concurrency=concurrency_value, ), runtime=RuntimeConfig( min_duration_ms=args.duration * 1000 @@ -359,9 +381,90 @@ def _run_benchmark( test_mode: TestMode, benchmark_mode: TestType | None, ) -> None: - """Execute the actual benchmark with full lifecycle management. + """Execute benchmark(s) - either single run or multiple runs for different concurrency values. + + This function handles the top-level orchestration: + 1. If target_concurrency is a single value or not applicable, run one benchmark + 2. If target_concurrency is a list, run multiple benchmarks sequentially + 3. Each benchmark run gets its own subdirectory (concurrency_{value}) + 4. Resources are fully cleaned up between runs to ensure isolation + + Args: + config: Validated BenchmarkConfig (immutable Pydantic model). + collect_responses: Whether to store full response text. + test_mode: What to collect - PERF, ACC, or BOTH. + benchmark_mode: Execution mode - OFFLINE or ONLINE. + """ + # Determine base report directory + if config.report_dir: + base_report_dir = Path(config.report_dir) + else: + base_report_dir = get_default_report_path() + + # Check if we need to run multiple benchmarks for different concurrency values + load_pattern = config.settings.load_pattern + target_concurrency = load_pattern.target_concurrency + + # If concurrency mode with list of values, run multiple benchmarks + if load_pattern.type == LoadPatternType.CONCURRENCY and isinstance( + target_concurrency, list + ): + logger.info( + f"Running {len(target_concurrency)} benchmarks with concurrency values: {target_concurrency}" + ) + + for i, concurrency_val in enumerate(target_concurrency): + logger.info( + f"\n{'='*80}\nBenchmark {i+1}/{len(target_concurrency)}: Concurrency = {concurrency_val}\n{'='*80}" + ) + + # Create subdirectory for this concurrency value + concurrency_report_dir = base_report_dir / f"concurrency_{concurrency_val}" + + try: + _run_single_benchmark( + config=config, + collect_responses=collect_responses, + test_mode=test_mode, + benchmark_mode=benchmark_mode, + report_dir=concurrency_report_dir, + concurrency_value=concurrency_val, + ) + except Exception as e: + logger.error(f"Benchmark failed for concurrency={concurrency_val}: {e}") + raise - This function orchestrates the complete benchmark execution: + logger.info( + f"Completed benchmark {i+1}/{len(target_concurrency)} (concurrency={concurrency_val})" + ) + + logger.info( + f"\n{'='*80}\nAll benchmarks completed successfully!\nResults saved to: {base_report_dir}\n{'='*80}" + ) + + else: + # Single benchmark run + _run_single_benchmark( + config=config, + collect_responses=collect_responses, + test_mode=test_mode, + benchmark_mode=benchmark_mode, + report_dir=base_report_dir, + concurrency_value=None, + ) + + +def _run_single_benchmark( + config: BenchmarkConfig, + collect_responses: bool, + test_mode: TestMode, + benchmark_mode: TestType | None, + report_dir: Path, + concurrency_value: int | None = None, +) -> None: + """Execute a single benchmark run with full lifecycle management. + + This function orchestrates a complete benchmark execution: 1. Load tokenizer for the target model 2. Load and validate dataset using DataLoaderFactory 3. Setup runtime settings and scheduler @@ -370,6 +473,11 @@ def _run_benchmark( 6. Collect and report results 7. Clean up resources (always, even on error) + When called as part of multiple concurrency runs: + - Creates fresh worker processes for isolation + - Clears event hooks to prevent cross-contamination + - Uses dedicated subdirectory for this run's results + Architecture notes: - This is a SYNCHRONOUS function (not async) because HTTPEndpointClient manages its own event loop in a separate thread @@ -382,7 +490,6 @@ def _run_benchmark( - Disabled for offline mode (max throughput focus) Args: - args: Command arguments containing output paths, verbosity, etc. config: Validated BenchmarkConfig (immutable Pydantic model). Contains all benchmark parameters from CLI or YAML. collect_responses: Whether to store full response text. @@ -391,6 +498,9 @@ def _run_benchmark( or BOTH (metrics + responses). benchmark_mode: Execution mode - OFFLINE (max throughput) or ONLINE (sustained QPS). Affects streaming and scheduling. + report_dir: Directory to write reports and results for this run. + concurrency_value: If set, overrides config's target_concurrency for this run. + Used when iterating through multiple concurrency values. Raises: InputValidationError: If model/dataset cannot be loaded or validated. @@ -411,12 +521,29 @@ def _run_benchmark( model_name = config.submission_ref.model config.model_params.name = model_name - if config.report_dir: - report_dir = Path(config.report_dir) - else: - report_dir = get_default_report_path() - + # Ensure report directory exists report_dir.mkdir(parents=True, exist_ok=True) + + # If concurrency_value is provided, create a modified config with single concurrency + if concurrency_value is not None: + # Create a new config with the specific concurrency value + config = BenchmarkConfig( + **{ + **config.model_dump(), + "settings": Settings( + **{ + **config.settings.model_dump(), + "load_pattern": LoadPattern( + **{ + **config.settings.load_pattern.model_dump(), + "target_concurrency": concurrency_value, + } + ), + } + ), + } + ) + config.to_yaml_file(report_dir / "config.yaml") if model_name: @@ -736,6 +863,11 @@ def signal_handler(signum, frame): sample_issuer.shutdown() http_client.shutdown() shutil.rmtree(tmp_dir, ignore_errors=True) + + # Clear event hooks to ensure clean state for next benchmark run + # This is critical when running multiple concurrency benchmarks + SampleEventHandler.clear_hooks() + except Exception as e: if config.verbose: logger.warning(f"Cleanup error: {e}") diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 96ca84c0..0dd83869 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -25,7 +25,7 @@ from typing import ClassVar import yaml -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from .. import metrics from .ruleset_base import BenchmarkSuiteRuleset @@ -266,13 +266,36 @@ class LoadPattern(BaseModel): - max_throughput: target_qps used for calculating total queries (offline, optional with default) - poisson: target_qps sets scheduler rate (online, required - validated) - concurrency: issue at fixed target_concurrency (online, required - validated) + + target_concurrency can be either: + - Single int: Run one benchmark with that concurrency level + - List of ints: Run multiple benchmarks sequentially, one per concurrency level """ type: LoadPatternType = LoadPatternType.MAX_THROUGHPUT target_qps: float | None = ( None # Target QPS - required for poisson pattern, optional otherwise ) - target_concurrency: int | None = None # For concurrency mode, ignored otherwise + target_concurrency: int | list[int] | None = ( + None # For concurrency mode, ignored otherwise + ) + + @field_validator("target_concurrency", mode="before") + @classmethod + def validate_target_concurrency(cls, v): + """Validate target_concurrency accepts int or list of ints.""" + if v is None: + return v + # Accept single int + if isinstance(v, int): + return v + # Accept list of ints + if isinstance(v, list): + return v + # Try to convert if it's something else (shouldn't happen with proper YAML) + raise ValueError( + f"target_concurrency must be an integer or list of integers, got {type(v)}" + ) class ClientSettings(BaseModel): @@ -524,10 +547,25 @@ def validate_load_pattern(self, benchmark_mode: TestType) -> None: ) elif load_pattern_type == LoadPatternType.CONCURRENCY: # Concurrency pattern requires target_concurrency > 0 - if not target_concurrency or target_concurrency <= 0: + # Can be single int or list of ints + if target_concurrency is None: + raise ValueError( + "Concurrency load pattern requires target_concurrency to be specified. " + "Specify number of concurrent requests (e.g., target_concurrency: 10 or target_concurrency: [10, 20, 30] in YAML or --concurrency 10 in CLI)" + ) + + # Validate single int or list of ints + if isinstance(target_concurrency, list): + if len(target_concurrency) == 0: + raise ValueError("target_concurrency list cannot be empty") + for i, conc in enumerate(target_concurrency): + if not isinstance(conc, int) or conc <= 0: + raise ValueError( + f"target_concurrency[{i}] must be a positive integer, got {conc}" + ) + elif not isinstance(target_concurrency, int) or target_concurrency <= 0: raise ValueError( - "Concurrency load pattern requires target_concurrency > 0. " - "Specify number of concurrent requests (e.g., target_concurrency: 10 under load_pattern in YAML or --concurrency 10 in CLI)" + f"target_concurrency must be a positive integer or list of positive integers, got {target_concurrency}" ) def validate_client_settings(self) -> None: diff --git a/src/inference_endpoint/load_generator/scheduler.py b/src/inference_endpoint/load_generator/scheduler.py index ae691d09..d8f2c120 100644 --- a/src/inference_endpoint/load_generator/scheduler.py +++ b/src/inference_endpoint/load_generator/scheduler.py @@ -379,7 +379,11 @@ def __init__(self, runtime_settings: RuntimeSettings, sample_order_cls): super().__init__(runtime_settings, sample_order_cls) assert runtime_settings.load_pattern is not None target_concurrency = runtime_settings.load_pattern.target_concurrency - if target_concurrency is None or target_concurrency <= 0: + if ( + target_concurrency is None + or (isinstance(target_concurrency, int) and target_concurrency <= 0) + or (isinstance(target_concurrency, list) and len(target_concurrency) == 0) + ): raise ValueError( f"target_concurrency must be > 0 for CONCURRENCY load pattern, got {target_concurrency}" )