From b663836d288fb095bc249fa15aea2ba421ad5a1f Mon Sep 17 00:00:00 2001 From: lowdy1 Date: Fri, 24 Apr 2026 08:52:41 +0000 Subject: [PATCH 1/4] benchmark refactor example --- benchmark/scripts/benchmark_layer_norm.py | 193 ++++++----------- benchmark/scripts/benchmark_model_configs.py | 87 +++++++- benchmark/scripts/benchmark_swiglu.py | 214 +++++++------------ 3 files changed, 237 insertions(+), 257 deletions(-) diff --git a/benchmark/scripts/benchmark_layer_norm.py b/benchmark/scripts/benchmark_layer_norm.py index 456e62b69..369a7d360 100644 --- a/benchmark/scripts/benchmark_layer_norm.py +++ b/benchmark/scripts/benchmark_layer_norm.py @@ -3,9 +3,8 @@ import torch from benchmark_model_configs import MODEL_REGISTRY -from benchmark_model_configs import compute_model_config_sweep_config -from benchmark_model_configs import compute_seq_len_sweep_config -from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import build_model_config_sweep +from benchmark_model_configs import build_token_length_sweep from benchmark_model_configs import get_benchmark_model_config from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -23,19 +22,28 @@ def _setup_layer_norm(input: SingleBenchmarkRunInput): """Create input tensor and LayerNorm layer from benchmark config.""" cfg = input.extra_benchmark_config - hidden_size = cfg["hidden_size"] + if isinstance(input.x, str): + model_cfg = MODEL_REGISTRY[input.x] + seq_len = cfg["seq_len"] + hidden_size = model_cfg.hidden_size + dtype = model_cfg.dtype + else: + seq_len = input.x + hidden_size = cfg["hidden_size"] + dtype = cfg["dtype"] + eps = cfg["eps"] x = torch.randn( - input.x, + seq_len, hidden_size, device=device, - dtype=cfg["dtype"], + dtype=dtype, requires_grad=True, ) if input.kernel_provider == "liger": - layer = LigerLayerNorm(hidden_size=hidden_size, eps=eps).to(device) + layer = LigerLayerNorm(hidden_size=hidden_size, eps=eps).to(device).to(dtype) elif input.kernel_provider == "huggingface": - layer = torch.nn.LayerNorm(hidden_size, eps=eps).to(device) + layer = torch.nn.LayerNorm(hidden_size, eps=eps).to(device).to(dtype) else: raise ValueError(f"Invalid provider: {input.kernel_provider} for LayerNorm") return x, layer @@ -51,100 +59,41 @@ def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) -def _resolve_model_config_layer_norm(input: SingleBenchmarkRunInput): - """Resolve model-config-sweep input into standard setup args.""" - cfg = input.extra_benchmark_config - model_info = cfg["model_configs"][input.x] - return _setup_layer_norm( - SingleBenchmarkRunInput( - x=cfg["BT"], - kernel_provider=input.kernel_provider, - extra_benchmark_config={ - "hidden_size": model_info["hidden_size"], - "dtype": model_info["dtype"], - "eps": cfg["eps"], - }, - ) - ) - - -def bench_speed_layer_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _resolve_model_config_layer_norm(input) - return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) - - -def bench_memory_layer_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _resolve_model_config_layer_norm(input) - return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) - - if __name__ == "__main__": args = parse_benchmark_script_args() if args.sweep_mode == "model_config": all_model_configs = list(MODEL_REGISTRY.values()) - def _probe_factory(model_cfg, probe_bt): - def _probe(): - probe_input = SingleBenchmarkRunInput( - x=probe_bt, - kernel_provider="huggingface", - extra_benchmark_config={ - "hidden_size": model_cfg.hidden_size, - "dtype": model_cfg.dtype, - "eps": 1e-6, - }, - ) - x, layer = _setup_layer_norm(probe_input) - return layer(x) - - return _probe - - sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) - - model_configs_info = { - cfg.name: { - "hidden_size": cfg.hidden_size, - "dtype": cfg.dtype, - } - for cfg in sweep.model_configs - } - - common_configs = { - "kernel_name": "layer_norm", - "x_name": "model_config", - "x_label": "model configuration", - "x_values": [cfg.name for cfg in sweep.model_configs], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "model_configs": model_configs_info, - "BT": sweep.bt, + def probe_fn(model_cfg, probe_bt): + probe_input = SingleBenchmarkRunInput( + x=probe_bt, + kernel_provider="huggingface", + extra_benchmark_config={ + "hidden_size": model_cfg.hidden_size, + "dtype": model_cfg.dtype, "eps": 1e-6, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_layer_norm_model_config, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_layer_norm_model_config, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="memory", - metric_unit="MB", - **common_configs, + }, + ) + x, layer = _setup_layer_norm(probe_input) + return layer(x) + + common_configs = build_model_config_sweep( + kernel_name="layer_norm", + all_model_configs=all_model_configs, + probe_fn=probe_fn, + extra_benchmark_config={ + "eps": 1e-6, + }, + bt=args.bt, + overwrite=args.overwrite, ) + else: model = get_benchmark_model_config(args.model) probe_bt = 1024 - def _probe(): + def probe_fn(): probe_input = SingleBenchmarkRunInput( x=probe_bt, kernel_provider="huggingface", @@ -157,38 +106,36 @@ def _probe(): x, layer = _setup_layer_norm(probe_input) return layer(x) - peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) - kernel_bpt = peak_bytes // probe_bt + def x_values_fn(config): + return [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)] + + common_configs = build_token_length_sweep( + kernel_name="layer_norm", + probe_seq_len=probe_bt, + model=model, + probe_fn=probe_fn, + extra_config_fn={ + "hidden_size": model.hidden_size, + "dtype": model.dtype, + "eps": 1e-6, + }, + x_values_fn=x_values_fn, + overwrite=args.overwrite, + ) - config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + common_configs["kernel_providers"] = ["liger", "huggingface"] - common_configs = { - "kernel_name": "layer_norm", - "x_name": "BT", - "x_label": "B * T", - "x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "hidden_size": model.hidden_size, - "dtype": model.dtype, - "eps": 1e-6, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_layer_norm, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_layer_norm, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="memory", - metric_unit="MB", - **common_configs, - ) + run_benchmarks( + bench_test_fn=bench_speed_layer_norm, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_layer_norm, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_model_configs.py b/benchmark/scripts/benchmark_model_configs.py index b7137942b..aebfe665a 100644 --- a/benchmark/scripts/benchmark_model_configs.py +++ b/benchmark/scripts/benchmark_model_configs.py @@ -341,7 +341,7 @@ def compute_seq_len_sweep_config( def compute_model_config_sweep_config( model_configs: List[ModelConfig], probe_fn_factory: Callable[[ModelConfig, int], Callable[[], torch.Tensor]], - bt: int = 2048, + bt: int = 1024, memory_utilization: float = 0.4, ) -> ModelConfigSweepConfig: """Find safe (batch_size, seq_len) that works across all model configs. @@ -383,3 +383,88 @@ def compute_model_config_sweep_config( batch_size=batch_size, seq_len=seq_len, ) + + +def build_model_config_sweep( + kernel_name: str, + all_model_configs: List[ModelConfig], + probe_fn: Callable[[ModelConfig, int], torch.Tensor], + extra_benchmark_config: Dict, + bt: int = 2048, + overwrite: bool = False, +) -> Dict: + """Build benchmark config dict for model-config sweep. + + Returns a single extra_benchmark_config (with a model_configs lookup dict + and seq_len filled in from the sweep), so run_benchmarks iterates over + model names as x_values rather than repeating per-model configs. + + Args: + kernel_name: Name of the kernel being benchmarked. + all_model_configs: List of model configs to sweep over. + probe_fn: Callable(model_cfg, probe_seq_len) -> output tensor for memory estimation. + extra_benchmark_config: Base config dict. seq_len will be set from sweep result. + bt: Target total tokens (batch_size * seq_len) used as probe bt. + overwrite: Whether to overwrite existing benchmark data. + """ + + def probe_fn_factory(model_cfg, probe_seq_len): + return lambda: probe_fn(model_cfg, probe_seq_len) + + sweep = compute_model_config_sweep_config( + all_model_configs, + probe_fn_factory=probe_fn_factory, + bt=bt, + ) + + config = {**extra_benchmark_config, "bsz": sweep.batch_size, "seq_len": sweep.seq_len} + + return { + "kernel_name": kernel_name, + "x_name": "model_config", + "x_label": "model configuration", + "x_values": [cfg.name for cfg in sweep.model_configs], + "extra_benchmark_configs": [config], + "overwrite": overwrite, + } + + +def build_token_length_sweep( + kernel_name: str, + probe_seq_len: int, + model: ModelConfig, + probe_fn: Callable[[], torch.Tensor], + extra_config_fn: Callable[[SeqLenSweepConfig], Dict] | Dict, + x_values_fn: Callable[[SeqLenSweepConfig], List[int]], + overwrite: bool = False, +) -> Dict: + """Build benchmark config dict for token-length sweep. + + Args: + kernel_name: Name of the kernel being benchmarked. + model: Model config to use for the sweep. + probe_fn: Callable() -> output tensor for memory estimation. + extra_config_fn: Callable(config) -> dict with normalized keys + that _setup_* expects. + x_values_fn: Callable(config) -> list of sequence lengths to benchmark. + overwrite: Whether to overwrite existing benchmark data. + + Returns: + Dict with keys: kernel_name, x_name, x_label, x_values, kernel_providers, + extra_benchmark_configs, overwrite. + """ + peak_bytes = estimate_kernel_peak_memory(probe_fn=probe_fn) + kernel_bpt = peak_bytes // probe_seq_len + + config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + + return { + "kernel_name": kernel_name, + "x_name": "T", + "x_label": "sequence length", + "x_values": x_values_fn(config), + "extra_benchmark_configs": [extra_config_fn] + if isinstance(extra_config_fn, dict) + else [extra_config_fn(config)], + "overwrite": overwrite, + } diff --git a/benchmark/scripts/benchmark_swiglu.py b/benchmark/scripts/benchmark_swiglu.py index dc34fd60d..8ff88c4d2 100644 --- a/benchmark/scripts/benchmark_swiglu.py +++ b/benchmark/scripts/benchmark_swiglu.py @@ -3,9 +3,8 @@ import torch from benchmark_model_configs import MODEL_REGISTRY -from benchmark_model_configs import compute_model_config_sweep_config -from benchmark_model_configs import compute_seq_len_sweep_config -from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import build_model_config_sweep +from benchmark_model_configs import build_token_length_sweep from benchmark_model_configs import get_benchmark_model_config from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaMLP @@ -25,23 +24,35 @@ def _setup_swiglu(input: SingleBenchmarkRunInput): """Create input tensor and SwiGLU layer from benchmark config.""" cfg = input.extra_benchmark_config + if isinstance(input.x, str): + model_cfg = MODEL_REGISTRY[input.x] + seq_len = cfg["seq_len"] + hidden_size = model_cfg.hidden_size + intermediate_size = model_cfg.intermediate_size + dtype = model_cfg.dtype + else: + seq_len = input.x + hidden_size = cfg["hidden_size"] + intermediate_size = cfg["intermediate_size"] + dtype = cfg["dtype"] + llama_config = LlamaConfig( - hidden_size=cfg["hidden_size"], - intermediate_size=cfg["intermediate_size"], + hidden_size=hidden_size, + intermediate_size=intermediate_size, hidden_act=cfg["hidden_act"], ) x = torch.randn( cfg["bsz"], - input.x, - cfg["hidden_size"], + seq_len, + hidden_size, device=device, - dtype=cfg["dtype"], + dtype=dtype, requires_grad=True, ) if input.kernel_provider == "liger": - layer = LigerSwiGLUMLP(config=llama_config).to(device).to(cfg["dtype"]) + layer = LigerSwiGLUMLP(config=llama_config).to(device).to(dtype) elif input.kernel_provider == "huggingface": - layer = LlamaMLP(config=llama_config).to(device).to(cfg["dtype"]) + layer = LlamaMLP(config=llama_config).to(device).to(dtype) else: raise ValueError(f"Invalid provider: {input.kernel_provider} for SwiGLU") return x, layer @@ -57,106 +68,42 @@ def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOut return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) -def _resolve_model_config_swiglu(input: SingleBenchmarkRunInput): - """Resolve model-config-sweep input into standard setup args.""" - cfg = input.extra_benchmark_config - model_info = cfg["model_configs"][input.x] - return _setup_swiglu( - SingleBenchmarkRunInput( - x=cfg["seq_len"], - kernel_provider=input.kernel_provider, - extra_benchmark_config={ - "bsz": cfg["bsz"], - "hidden_size": model_info["hidden_size"], - "intermediate_size": model_info["intermediate_size"], - "hidden_act": cfg["hidden_act"], - "dtype": model_info["dtype"], - }, - ) - ) - - -def bench_speed_swiglu_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _resolve_model_config_swiglu(input) - return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) - - -def bench_memory_swiglu_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _resolve_model_config_swiglu(input) - return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) - - if __name__ == "__main__": args = parse_benchmark_script_args() if args.sweep_mode == "model_config": all_model_configs = list(MODEL_REGISTRY.values()) - def _probe_factory(model_cfg, probe_seq_len): - def _probe(): - probe_input = SingleBenchmarkRunInput( - x=probe_seq_len, - kernel_provider="huggingface", - extra_benchmark_config={ - "bsz": 1, - "hidden_size": model_cfg.hidden_size, - "intermediate_size": model_cfg.intermediate_size, - "hidden_act": "silu", - "dtype": model_cfg.dtype, - }, - ) - x, layer = _setup_swiglu(probe_input) - return layer(x) - - return _probe - - sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) - - model_configs_info = { - cfg.name: { - "hidden_size": cfg.hidden_size, - "intermediate_size": cfg.intermediate_size, - "dtype": cfg.dtype, - } - for cfg in sweep.model_configs - } - - common_configs = { - "kernel_name": "swiglu", - "x_name": "model_config", - "x_label": "model configuration", - "x_values": [cfg.name for cfg in sweep.model_configs], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "model_configs": model_configs_info, - "bsz": sweep.batch_size, - "seq_len": sweep.seq_len, + def probe_fn(model_cfg, probe_seq_len): + probe_input = SingleBenchmarkRunInput( + x=probe_seq_len, + kernel_provider="huggingface", + extra_benchmark_config={ + "bsz": 1, "hidden_act": "silu", - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_swiglu_model_config, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_swiglu_model_config, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="memory", - metric_unit="MB", - **common_configs, + "hidden_size": model_cfg.hidden_size, + "intermediate_size": model_cfg.intermediate_size, + "dtype": model_cfg.dtype, + }, + ) + x, layer = _setup_swiglu(probe_input) + return layer(x) + + common_configs = build_model_config_sweep( + kernel_name="swiglu", + all_model_configs=all_model_configs, + probe_fn=probe_fn, + extra_benchmark_config={ + "hidden_act": "silu", + }, + bt=args.bt, + overwrite=args.overwrite, ) else: model = get_benchmark_model_config(args.model) probe_seq_len = 1024 - def _probe(): + def probe_fn(): probe_input = SingleBenchmarkRunInput( x=probe_seq_len, kernel_provider="huggingface", @@ -171,40 +118,41 @@ def _probe(): x, layer = _setup_swiglu(probe_input) return layer(x) - peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) - kernel_bpt = peak_bytes // probe_seq_len - - config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + def extra_config_fn(config): + return { + "bsz": config.batch_size, + "hidden_size": model.hidden_size, + "intermediate_size": model.intermediate_size, + "hidden_act": "silu", + "dtype": model.dtype, + } - common_configs = { - "kernel_name": "swiglu", - "x_name": "T", - "x_label": "sequence length", - "x_values": [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "bsz": config.batch_size, - "hidden_size": model.hidden_size, - "intermediate_size": model.intermediate_size, - "hidden_act": "silu", - "dtype": model.dtype, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_swiglu, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_swiglu, - kernel_operation_modes=["full", "forward", "backward"], - metric_name="memory", - metric_unit="MB", - **common_configs, + def x_values_fn(config): + return [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)] + + common_configs = build_token_length_sweep( + kernel_name="swiglu", + probe_seq_len=probe_seq_len, + model=model, + probe_fn=probe_fn, + extra_config_fn=extra_config_fn, + x_values_fn=x_values_fn, + overwrite=args.overwrite, ) + + common_configs["kernel_providers"] = ["liger", "huggingface"] + + run_benchmarks( + bench_test_fn=bench_speed_swiglu, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=bench_memory_swiglu, + kernel_operation_modes=["full", "forward", "backward"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) From a04ea3ca112cebc7797c3bde179a9eda955bd83f Mon Sep 17 00:00:00 2001 From: lowdy1 Date: Sat, 25 Apr 2026 07:59:15 +0000 Subject: [PATCH 2/4] move probe into builder --- benchmark/README.md | 3 +- benchmark/scripts/benchmark_layer_norm.py | 50 ++----- benchmark/scripts/benchmark_model_configs.py | 149 +++++++++++++++---- benchmark/scripts/benchmark_swiglu.py | 58 ++------ 4 files changed, 142 insertions(+), 118 deletions(-) diff --git a/benchmark/README.md b/benchmark/README.md index aefba404b..42f20eb62 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -12,7 +12,8 @@ Follow these steps to benchmark and visualize kernel performance: Example: Benchmarking KTO Loss ```bash cd benchmark - python scripts/benchmark_kto_loss.py + python scripts/benchmark_kto_loss.py --sweep-mode model_config [--model llama_3_8b] + python scripts/benchmark_kto_loss.py [--sweep-mode token_length] [--bt 2048] ``` 3. Visualize results diff --git a/benchmark/scripts/benchmark_layer_norm.py b/benchmark/scripts/benchmark_layer_norm.py index 369a7d360..1686daa7c 100644 --- a/benchmark/scripts/benchmark_layer_norm.py +++ b/benchmark/scripts/benchmark_layer_norm.py @@ -1,5 +1,3 @@ -import math - import torch from benchmark_model_configs import MODEL_REGISTRY @@ -65,61 +63,33 @@ def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu if args.sweep_mode == "model_config": all_model_configs = list(MODEL_REGISTRY.values()) - def probe_fn(model_cfg, probe_bt): - probe_input = SingleBenchmarkRunInput( - x=probe_bt, - kernel_provider="huggingface", - extra_benchmark_config={ - "hidden_size": model_cfg.hidden_size, - "dtype": model_cfg.dtype, - "eps": 1e-6, - }, - ) - x, layer = _setup_layer_norm(probe_input) - return layer(x) - common_configs = build_model_config_sweep( kernel_name="layer_norm", all_model_configs=all_model_configs, - probe_fn=probe_fn, - extra_benchmark_config={ + setup_fn=_setup_layer_norm, + model_keys=["hidden_size", "dtype"], + extra_configs={ "eps": 1e-6, }, + probe_provider="huggingface", bt=args.bt, overwrite=args.overwrite, ) else: model = get_benchmark_model_config(args.model) - probe_bt = 1024 - - def probe_fn(): - probe_input = SingleBenchmarkRunInput( - x=probe_bt, - kernel_provider="huggingface", - extra_benchmark_config={ - "hidden_size": model.hidden_size, - "dtype": model.dtype, - "eps": 1e-6, - }, - ) - x, layer = _setup_layer_norm(probe_input) - return layer(x) - - def x_values_fn(config): - return [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)] + probe_seq_len = 1024 common_configs = build_token_length_sweep( kernel_name="layer_norm", - probe_seq_len=probe_bt, + probe_seq_len=probe_seq_len, model=model, - probe_fn=probe_fn, - extra_config_fn={ - "hidden_size": model.hidden_size, - "dtype": model.dtype, + setup_fn=_setup_layer_norm, + model_keys=["hidden_size", "dtype"], + extra_configs={ "eps": 1e-6, }, - x_values_fn=x_values_fn, + probe_provider="huggingface", overwrite=args.overwrite, ) diff --git a/benchmark/scripts/benchmark_model_configs.py b/benchmark/scripts/benchmark_model_configs.py index aebfe665a..6f96a7303 100644 --- a/benchmark/scripts/benchmark_model_configs.py +++ b/benchmark/scripts/benchmark_model_configs.py @@ -26,6 +26,7 @@ import math from dataclasses import dataclass +from typing import Any from typing import Callable from typing import Dict from typing import List @@ -34,6 +35,8 @@ import torch +from utils import SingleBenchmarkRunInput + from liger_kernel.utils import get_total_gpu_memory from liger_kernel.utils import infer_device @@ -385,31 +388,86 @@ def compute_model_config_sweep_config( ) +def build_extra_config( + model: ModelConfig, + model_keys: List[str], + extra_configs: Optional[Dict] = None, +) -> Dict: + """Construct extra_benchmark_config dict. + + Args: + model: The model configuration object. + model_keys: List of attribute names to read from `model` + (e.g. ["hidden_size", "dtype"]). + extra_configs: Optional dictionary of additional key/value pairs + that override or extend the extracted attributes. + """ + extra_configs = extra_configs or {} + cfg = {k: getattr(model, k) for k in model_keys} + cfg.update(extra_configs) + return cfg + + +def _default_forward_fn(x, layer): + """Default forward function for common (input, module) patterns. + + Assumes `setup_fn` returns `(x, layer)` and simply applies: + layer(x) + + This covers the majority of kernels that follow a + "tensor + nn.Module" execution pattern. + """ + return layer(x) + + def build_model_config_sweep( kernel_name: str, all_model_configs: List[ModelConfig], - probe_fn: Callable[[ModelConfig, int], torch.Tensor], - extra_benchmark_config: Dict, + setup_fn: Callable[[SingleBenchmarkRunInput], Tuple[Any, ...]], + model_keys: List[str], + forward_fn: Callable[..., torch.Tensor] = _default_forward_fn, + probe_provider: str = "torch", + extra_configs: Optional[Dict] = None, bt: int = 2048, overwrite: bool = False, ) -> Dict: """Build benchmark config dict for model-config sweep. - Returns a single extra_benchmark_config (with a model_configs lookup dict - and seq_len filled in from the sweep), so run_benchmarks iterates over - model names as x_values rather than repeating per-model configs. - Args: kernel_name: Name of the kernel being benchmarked. - all_model_configs: List of model configs to sweep over. - probe_fn: Callable(model_cfg, probe_seq_len) -> output tensor for memory estimation. - extra_benchmark_config: Base config dict. seq_len will be set from sweep result. - bt: Target total tokens (batch_size * seq_len) used as probe bt. - overwrite: Whether to overwrite existing benchmark data. + all_model_configs: List of model configurations to sweep over. + setup_fn: Function that prepares inputs and modules given a + `SingleBenchmarkRunInput`. Returns a tuple of objects consumed + by `forward_fn`. + model_keys: List of attributes to extract from each `ModelConfig` + and include in `extra_benchmark_config`. + forward_fn: Function that executes the kernel given the outputs of + `setup_fn`. Defaults to `(x, layer) -> layer(x)`. + probe_provider: Kernel provider used during memory probing. + extra_configs: Optional static overrides merged into the benchmark config. + bt: Target total tokens (batch_size * seq_len) used to derive sweep. + overwrite: Whether to overwrite existing benchmark results. + + Returns: + A dictionary consumable by `run_benchmarks`. """ def probe_fn_factory(model_cfg, probe_seq_len): - return lambda: probe_fn(model_cfg, probe_seq_len) + def _probe(): + probe_input = SingleBenchmarkRunInput( + x=probe_seq_len, + kernel_provider=probe_provider, + extra_benchmark_config=build_extra_config( + model_cfg, + model_keys, + extra_configs=extra_configs, + ), + ) + setup_out = setup_fn(probe_input) + + return forward_fn(*setup_out) + + return _probe sweep = compute_model_config_sweep_config( all_model_configs, @@ -417,14 +475,17 @@ def probe_fn_factory(model_cfg, probe_seq_len): bt=bt, ) - config = {**extra_benchmark_config, "bsz": sweep.batch_size, "seq_len": sweep.seq_len} + base_config = {"bsz": sweep.batch_size, "seq_len": sweep.seq_len} + + if extra_configs: + base_config.update(extra_configs) return { "kernel_name": kernel_name, "x_name": "model_config", "x_label": "model configuration", "x_values": [cfg.name for cfg in sweep.model_configs], - "extra_benchmark_configs": [config], + "extra_benchmark_configs": [base_config], "overwrite": overwrite, } @@ -433,38 +494,70 @@ def build_token_length_sweep( kernel_name: str, probe_seq_len: int, model: ModelConfig, - probe_fn: Callable[[], torch.Tensor], - extra_config_fn: Callable[[SeqLenSweepConfig], Dict] | Dict, - x_values_fn: Callable[[SeqLenSweepConfig], List[int]], + setup_fn: Callable[[SingleBenchmarkRunInput], Tuple[Any, ...]], + model_keys: List[str], + extra_configs: Optional[Dict] = None, + forward_fn: Callable[..., torch.Tensor] = _default_forward_fn, + probe_provider: str = "torch", + x_values_fn: Optional[Callable[[SeqLenSweepConfig], List[int]]] = None, overwrite: bool = False, ) -> Dict: """Build benchmark config dict for token-length sweep. Args: kernel_name: Name of the kernel being benchmarked. - model: Model config to use for the sweep. - probe_fn: Callable() -> output tensor for memory estimation. - extra_config_fn: Callable(config) -> dict with normalized keys - that _setup_* expects. - x_values_fn: Callable(config) -> list of sequence lengths to benchmark. - overwrite: Whether to overwrite existing benchmark data. + probe_seq_len: Sequence length used for memory probing. + model: Model configuration used for the sweep. + setup_fn: Function that prepares inputs and modules given a + `SingleBenchmarkRunInput`. Returns a tuple of objects consumed + by `forward_fn`. + model_keys: List of attributes to extract from `model` and include + in `extra_benchmark_config`. + extra_configs: Optional static overrides merged into the config. + forward_fn: Function that executes the kernel given the outputs of + `setup_fn`. Defaults to `(x, layer) -> layer(x)`. + probe_provider: Kernel provider used during memory probing. + x_values_fn: Optional function mapping `SeqLenSweepConfig` to a list + of sequence lengths. Defaults to powers of 2 up to max seq_len. + overwrite: Whether to overwrite existing benchmark results. Returns: - Dict with keys: kernel_name, x_name, x_label, x_values, kernel_providers, - extra_benchmark_configs, overwrite. + A dictionary consumable by `run_benchmarks`. """ + extra_configs = extra_configs or {} + + def probe_fn(): + probe_input = SingleBenchmarkRunInput( + x=probe_seq_len, + kernel_provider=probe_provider, + extra_benchmark_config=build_extra_config( + model, + model_keys, + extra_configs=extra_configs, + ), + ) + setup_out = setup_fn(probe_input) + return forward_fn(*setup_out) + peak_bytes = estimate_kernel_peak_memory(probe_fn=probe_fn) kernel_bpt = peak_bytes // probe_seq_len config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) + if x_values_fn is None: + x_values_fn = lambda cfg: [2**i for i in range(10, int(math.log2(cfg.seq_len)) + 1)] + return { "kernel_name": kernel_name, "x_name": "T", "x_label": "sequence length", "x_values": x_values_fn(config), - "extra_benchmark_configs": [extra_config_fn] - if isinstance(extra_config_fn, dict) - else [extra_config_fn(config)], + "extra_benchmark_configs": [ + build_extra_config( + model, + model_keys, + extra_configs=extra_configs, + ) + ], "overwrite": overwrite, } diff --git a/benchmark/scripts/benchmark_swiglu.py b/benchmark/scripts/benchmark_swiglu.py index 8ff88c4d2..fb95bb167 100644 --- a/benchmark/scripts/benchmark_swiglu.py +++ b/benchmark/scripts/benchmark_swiglu.py @@ -1,5 +1,3 @@ -import math - import torch from benchmark_model_configs import MODEL_REGISTRY @@ -74,26 +72,14 @@ def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOut if args.sweep_mode == "model_config": all_model_configs = list(MODEL_REGISTRY.values()) - def probe_fn(model_cfg, probe_seq_len): - probe_input = SingleBenchmarkRunInput( - x=probe_seq_len, - kernel_provider="huggingface", - extra_benchmark_config={ - "bsz": 1, - "hidden_act": "silu", - "hidden_size": model_cfg.hidden_size, - "intermediate_size": model_cfg.intermediate_size, - "dtype": model_cfg.dtype, - }, - ) - x, layer = _setup_swiglu(probe_input) - return layer(x) - common_configs = build_model_config_sweep( kernel_name="swiglu", all_model_configs=all_model_configs, - probe_fn=probe_fn, - extra_benchmark_config={ + setup_fn=_setup_swiglu, + model_keys=["hidden_size", "intermediate_size", "dtype"], + probe_provider="huggingface", + extra_configs={ + "bsz": 1, "hidden_act": "silu", }, bt=args.bt, @@ -103,40 +89,14 @@ def probe_fn(model_cfg, probe_seq_len): model = get_benchmark_model_config(args.model) probe_seq_len = 1024 - def probe_fn(): - probe_input = SingleBenchmarkRunInput( - x=probe_seq_len, - kernel_provider="huggingface", - extra_benchmark_config={ - "bsz": 1, - "hidden_size": model.hidden_size, - "intermediate_size": model.intermediate_size, - "hidden_act": "silu", - "dtype": model.dtype, - }, - ) - x, layer = _setup_swiglu(probe_input) - return layer(x) - - def extra_config_fn(config): - return { - "bsz": config.batch_size, - "hidden_size": model.hidden_size, - "intermediate_size": model.intermediate_size, - "hidden_act": "silu", - "dtype": model.dtype, - } - - def x_values_fn(config): - return [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)] - common_configs = build_token_length_sweep( kernel_name="swiglu", probe_seq_len=probe_seq_len, model=model, - probe_fn=probe_fn, - extra_config_fn=extra_config_fn, - x_values_fn=x_values_fn, + setup_fn=_setup_swiglu, + model_keys=["hidden_size", "intermediate_size", "dtype"], + extra_configs={"hidden_act": "silu", "bsz": 1}, + probe_provider="huggingface", overwrite=args.overwrite, ) From ad3072a35c58dd4f8e1ec7653c606ef7ca24505e Mon Sep 17 00:00:00 2001 From: lowdy1 Date: Mon, 27 Apr 2026 03:38:28 +0000 Subject: [PATCH 3/4] add readme instruction --- benchmark/README.md | 83 ++++- benchmark/scripts/benchmark_cpo_loss.py | 305 ++++--------------- benchmark/scripts/benchmark_layer_norm.py | 27 +- benchmark/scripts/benchmark_model_configs.py | 56 ++-- benchmark/scripts/benchmark_swiglu.py | 27 +- benchmark/scripts/utils.py | 27 ++ 6 files changed, 210 insertions(+), 315 deletions(-) diff --git a/benchmark/README.md b/benchmark/README.md index 42f20eb62..fac77ccac 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -1,6 +1,87 @@ ## Benchmarking Liger Kernels -Follow these steps to benchmark and visualize kernel performance: +### Benchmark Framework Overview + +The benchmarking system is designed to provide a **consistent, low-boilerplate way** to evaluate kernel performance across: + +* Different **model configurations** (e.g., LLaMA and Qwen variants) +* Different **sequence lengths / Batch size * token length** +* Multiple **kernel providers** (e.g., `liger`, `huggingface`) + +#### Core Concepts + +1. `setup_fn` + + Defines how to **construct inputs and modules** for a single forward pass. + + * Input: `SingleBenchmarkRunInput` + * Output: tuple of tensors / modules + + + ```python + def _setup_fn(input: SingleBenchmarkRunInput) -> Tuple[Any, ...]: + x = ... + layer = ... + return x, layer + ``` + +2. Benchmark Function Builders + + Reusable helpers to generate benchmark functions handle: + + * forward / backward / full modes + * timing and memory measurement + + ```python + build_speed_bench_fn(setup_fn) + build_memory_bench_fn(setup_fn) + ``` + +3. Sweep Builders + + (a) `build_model_config_sweep` + + * Sweeps across **model configurations** + * Keeps total tokens (`B * T`) approximately constant + + ```python + common_configs = build_model_config_sweep( + kernel_name=..., + all_model_configs=..., + setup_fn=..., + model_keys=[...], + ) + ``` + + (b) `build_token_length_sweep` + + * Sweeps across **sequence length (T)** + * Keep one static model + + ```python + common_configs = build_token_length_sweep( + kernel_name=..., + probe_x=..., + model=..., + setup_fn=..., + model_keys=[...], + ) + ``` + +4. `model_keys` and `extra_configs` + + * `model_keys`: attributes pulled from `ModelConfig` + + * e.g. `["hidden_size", "dtype"]` + + * `extra_configs`: static overrides + + * e.g. `{"eps": 1e-6}` + + These form `extra_benchmark_config`, passed into `setup_fn`. + + +### Benchmark workflow: 1. Create a benchmark script - Add your script under `benchmark/scripts/` diff --git a/benchmark/scripts/benchmark_cpo_loss.py b/benchmark/scripts/benchmark_cpo_loss.py index 07501275e..31fb8f734 100644 --- a/benchmark/scripts/benchmark_cpo_loss.py +++ b/benchmark/scripts/benchmark_cpo_loss.py @@ -3,17 +3,14 @@ import sys import torch -import triton from benchmark_model_configs import MODEL_REGISTRY -from benchmark_model_configs import compute_model_config_sweep_config -from benchmark_model_configs import compute_seq_len_sweep_config -from benchmark_model_configs import estimate_kernel_peak_memory +from benchmark_model_configs import build_model_config_sweep +from benchmark_model_configs import build_token_length_sweep from benchmark_model_configs import get_benchmark_model_config -from utils import QUANTILES from utils import SingleBenchmarkRunInput -from utils import SingleBenchmarkRunOutput -from utils import _test_memory +from utils import build_memory_bench_fn +from utils import build_speed_bench_fn from utils import parse_benchmark_script_args from utils import run_benchmarks @@ -24,17 +21,24 @@ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) -def _setup_cpo_loss(input: SingleBenchmarkRunInput): +def setup_cpo_loss(input: SingleBenchmarkRunInput): """Create input tensors and CPO loss from benchmark config.""" from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO cfg = input.extra_benchmark_config - H = cfg["hidden_size"] - V = cfg["vocab_size"] - dtype = cfg["dtype"] - B = input.x T = cfg["T"] + if isinstance(input.x, str): + model_cfg = MODEL_REGISTRY[input.x] + H = model_cfg.hidden_size + V = model_cfg.vocab_size + dtype = model_cfg.dtype + B = cfg["bsz"] + else: + B = input.x + H = cfg["hidden_size"] + V = cfg["vocab_size"] + dtype = cfg["dtype"] _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) target = torch.randint(V, (B, T), dtype=torch.long, device=device) @@ -46,253 +50,60 @@ def _setup_cpo_loss(input: SingleBenchmarkRunInput): else: raise ValueError(f"Invalid provider: {input.kernel_provider} for CPOLoss") - fwd_fn = lambda: loss_module(_input, target)[0] + fwd_fn = lambda x: loss_module(x, target)[0] return _input, fwd_fn -def bench_speed_cpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - _input, fwd_fn = _setup_cpo_loss(input) - mode = input.kernel_operation_mode - - if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd_fn, - rep=100, - quantiles=QUANTILES, - ) - elif mode == "backward": - y = fwd_fn() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: y.backward(retain_graph=True), - grad_to_none=[_input], - rep=100, - quantiles=QUANTILES, - ) - elif mode == "full": - - def full(): - y = fwd_fn() - y.backward() - - ms_50, ms_20, ms_80 = triton.testing.do_bench(full, rep=100, quantiles=QUANTILES) - else: - raise ValueError(f"Unsupported mode: {mode}") - - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) - - -def bench_memory_cpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - _input, fwd_fn = _setup_cpo_loss(input) - - def full(): - y = fwd_fn() - y.backward() - - mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) - - -def _resolve_model_config_cpo_loss(input: SingleBenchmarkRunInput): - """Resolve model-config-sweep input into standard setup args.""" - cfg = input.extra_benchmark_config - model_info = cfg["model_configs"][input.x] - return _setup_cpo_loss( - SingleBenchmarkRunInput( - x=cfg["B"], - kernel_provider=input.kernel_provider, - extra_benchmark_config={ - "hidden_size": model_info["hidden_size"], - "vocab_size": model_info["vocab_size"], - "dtype": model_info["dtype"], - "T": cfg["T"], - }, - ) - ) - - -def bench_speed_cpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - _input, fwd = _resolve_model_config_cpo_loss(input) - mode = input.kernel_operation_mode - - if mode == "forward": - ms_50, ms_20, ms_80 = triton.testing.do_bench( - fwd, - rep=100, - quantiles=QUANTILES, - ) - elif mode == "backward": - y = fwd() - ms_50, ms_20, ms_80 = triton.testing.do_bench( - lambda: y.backward(retain_graph=True), - grad_to_none=[_input], - rep=100, - quantiles=QUANTILES, - ) - elif mode == "full": - - def full(): - y = fwd() - y.backward() - - ms_50, ms_20, ms_80 = triton.testing.do_bench( - full, - rep=100, - quantiles=QUANTILES, - ) - else: - raise ValueError(f"Unsupported mode: {mode}") - - return SingleBenchmarkRunOutput( - y_20=ms_20, - y_50=ms_50, - y_80=ms_80, - ) - - -def bench_memory_cpo_loss_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - _input, fwd_fn = _resolve_model_config_cpo_loss(input) - - def full(): - y = fwd_fn() - y.backward() - - mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES) - return SingleBenchmarkRunOutput( - y_20=mem_20, - y_50=mem_50, - y_80=mem_80, - ) - - if __name__ == "__main__": args = parse_benchmark_script_args() + T = 1024 if args.sweep_mode == "model_config": all_model_configs = list(MODEL_REGISTRY.values()) - T = 1024 - - def _probe_factory(model_cfg, probe_bt): - def _probe(): - B = max(1, probe_bt // T) - probe_input = SingleBenchmarkRunInput( - x=B, - kernel_provider="huggingface", - extra_benchmark_config={ - "hidden_size": model_cfg.hidden_size, - "vocab_size": model_cfg.vocab_size, - "dtype": model_cfg.dtype, - "T": T, - }, - ) - _, fwd_fn = _setup_cpo_loss(probe_input) - return fwd_fn() - - return _probe - - sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt) - - model_configs_info = { - cfg.name: { - "hidden_size": cfg.hidden_size, - "vocab_size": cfg.vocab_size, - "dtype": cfg.dtype, - } - for cfg in sweep.model_configs - } - B = max(1, sweep.bt // T) - - common_configs = { - "kernel_name": "fused_linear_cpo_loss", - "x_name": "model_config", - "x_label": "model configuration", - "x_values": [cfg.name for cfg in sweep.model_configs], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "model_configs": model_configs_info, - "B": B, - "T": T, - } - ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_cpo_loss_model_config, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_cpo_loss_model_config, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, + common_configs = build_model_config_sweep( + kernel_name="cpo_loss", + probe_x=1, # batch_size + all_model_configs=all_model_configs, + setup_fn=setup_cpo_loss, + model_keys=["hidden_size", "vocab_size", "dtype"], + extra_configs={"T": T}, + probe_provider="huggingface", + bt=args.bt, + overwrite=args.overwrite, ) else: model = get_benchmark_model_config(args.model) - T = 1024 - probe_bt = 1024 - - def _probe(): - B = probe_bt // T - probe_input = SingleBenchmarkRunInput( - x=B, - kernel_provider="huggingface", - extra_benchmark_config={ - "hidden_size": model.hidden_size, - "vocab_size": model.vocab_size, - "dtype": model.dtype, - "T": T, - }, - ) - _, fwd_fn = _setup_cpo_loss(probe_input) - return fwd_fn() - - peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe) - kernel_bpt = peak_bytes // probe_bt - - config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) - common_configs = { - "kernel_name": "fused_linear_cpo_loss", - "x_name": "B", - "x_label": "Batch Size (B)", - "x_values": [2**i for i in range(1, int(math.log2(max(2, config.batch_size * config.seq_len // T))) + 1)], - "kernel_providers": ["liger", "huggingface"], - "extra_benchmark_configs": [ - { - "hidden_size": model.hidden_size, - "vocab_size": model.vocab_size, - "dtype": model.dtype, - "T": T, - } + common_configs = build_token_length_sweep( + kernel_name="cpo_loss", + probe_x=1, + model=model, + setup_fn=setup_cpo_loss, + model_keys=["hidden_size", "vocab_size", "dtype"], + extra_configs={"T": T}, + probe_provider="huggingface", + x_values_fn=lambda config: [ + 2**i for i in range(1, int(math.log2(max(2, config.batch_size * config.seq_len // T))) + 1) ], - "overwrite": args.overwrite, - } - - run_benchmarks( - bench_test_fn=bench_speed_cpo_loss, - kernel_operation_modes=["forward", "backward", "full"], - metric_name="speed", - metric_unit="ms", - **common_configs, - ) - run_benchmarks( - bench_test_fn=bench_memory_cpo_loss, - kernel_operation_modes=["full"], - metric_name="memory", - metric_unit="MB", - **common_configs, + x_name="B", # default x is seq_len, but for CPO loss we want to sweep batch size instead + x_label="Batch Size", + overwrite=args.overwrite, ) + + common_configs["kernel_providers"] = ["liger", "huggingface"] + + run_benchmarks( + bench_test_fn=build_speed_bench_fn(setup_cpo_loss), + kernel_operation_modes=["forward", "backward", "full"], + metric_name="speed", + metric_unit="ms", + **common_configs, + ) + run_benchmarks( + bench_test_fn=build_memory_bench_fn(setup_cpo_loss), + kernel_operation_modes=["full"], + metric_name="memory", + metric_unit="MB", + **common_configs, + ) diff --git a/benchmark/scripts/benchmark_layer_norm.py b/benchmark/scripts/benchmark_layer_norm.py index 1686daa7c..77456de32 100644 --- a/benchmark/scripts/benchmark_layer_norm.py +++ b/benchmark/scripts/benchmark_layer_norm.py @@ -5,11 +5,10 @@ from benchmark_model_configs import build_token_length_sweep from benchmark_model_configs import get_benchmark_model_config from utils import SingleBenchmarkRunInput -from utils import SingleBenchmarkRunOutput +from utils import build_memory_bench_fn +from utils import build_speed_bench_fn from utils import parse_benchmark_script_args from utils import run_benchmarks -from utils import run_memory_benchmark -from utils import run_speed_benchmark from liger_kernel.transformers.layer_norm import LigerLayerNorm from liger_kernel.utils import infer_device @@ -17,7 +16,7 @@ device = infer_device() -def _setup_layer_norm(input: SingleBenchmarkRunInput): +def setup_layer_norm(input: SingleBenchmarkRunInput): """Create input tensor and LayerNorm layer from benchmark config.""" cfg = input.extra_benchmark_config if isinstance(input.x, str): @@ -47,16 +46,6 @@ def _setup_layer_norm(input: SingleBenchmarkRunInput): return x, layer -def bench_speed_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _setup_layer_norm(input) - return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) - - -def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _setup_layer_norm(input) - return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) - - if __name__ == "__main__": args = parse_benchmark_script_args() @@ -66,7 +55,7 @@ def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu common_configs = build_model_config_sweep( kernel_name="layer_norm", all_model_configs=all_model_configs, - setup_fn=_setup_layer_norm, + setup_fn=setup_layer_norm, model_keys=["hidden_size", "dtype"], extra_configs={ "eps": 1e-6, @@ -82,9 +71,9 @@ def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu common_configs = build_token_length_sweep( kernel_name="layer_norm", - probe_seq_len=probe_seq_len, + probe_x=probe_seq_len, model=model, - setup_fn=_setup_layer_norm, + setup_fn=setup_layer_norm, model_keys=["hidden_size", "dtype"], extra_configs={ "eps": 1e-6, @@ -96,14 +85,14 @@ def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu common_configs["kernel_providers"] = ["liger", "huggingface"] run_benchmarks( - bench_test_fn=bench_speed_layer_norm, + bench_test_fn=build_speed_bench_fn(setup_layer_norm), kernel_operation_modes=["full", "forward", "backward"], metric_name="speed", metric_unit="ms", **common_configs, ) run_benchmarks( - bench_test_fn=bench_memory_layer_norm, + bench_test_fn=build_memory_bench_fn(setup_layer_norm), kernel_operation_modes=["full", "forward", "backward"], metric_name="memory", metric_unit="MB", diff --git a/benchmark/scripts/benchmark_model_configs.py b/benchmark/scripts/benchmark_model_configs.py index 6f96a7303..52a5387cf 100644 --- a/benchmark/scripts/benchmark_model_configs.py +++ b/benchmark/scripts/benchmark_model_configs.py @@ -36,6 +36,7 @@ import torch from utils import SingleBenchmarkRunInput +from utils import default_forward_fn from liger_kernel.utils import get_total_gpu_memory from liger_kernel.utils import infer_device @@ -354,7 +355,7 @@ def compute_model_config_sweep_config( Args: model_configs: Model configs to benchmark. - probe_fn_factory: Factory ``(model_cfg, probe_seq_len) -> probe_fn``. + probe_fn_factory: Factory ``(model_cfg) -> probe_fn``. The returned probe_fn should perform setup + forward pass and return a tensor suitable for ``.backward()``, same contract as :func:`estimate_kernel_peak_memory`'s *probe_fn*. @@ -368,7 +369,7 @@ def compute_model_config_sweep_config( max_bytes_per_token = 0 for model_cfg in model_configs: - probe_fn = probe_fn_factory(model_cfg, probe_seq_len) + probe_fn = probe_fn_factory(model_cfg) peak_bytes = estimate_kernel_peak_memory(probe_fn) bpt = max(1, peak_bytes // probe_seq_len) max_bytes_per_token = max(max_bytes_per_token, bpt) @@ -408,27 +409,16 @@ def build_extra_config( return cfg -def _default_forward_fn(x, layer): - """Default forward function for common (input, module) patterns. - - Assumes `setup_fn` returns `(x, layer)` and simply applies: - layer(x) - - This covers the majority of kernels that follow a - "tensor + nn.Module" execution pattern. - """ - return layer(x) - - def build_model_config_sweep( kernel_name: str, all_model_configs: List[ModelConfig], setup_fn: Callable[[SingleBenchmarkRunInput], Tuple[Any, ...]], model_keys: List[str], - forward_fn: Callable[..., torch.Tensor] = _default_forward_fn, - probe_provider: str = "torch", + forward_fn: Callable[..., torch.Tensor] = default_forward_fn, + probe_provider: str = "huggingface", extra_configs: Optional[Dict] = None, bt: int = 2048, + probe_x: int = None, overwrite: bool = False, ) -> Dict: """Build benchmark config dict for model-config sweep. @@ -445,17 +435,20 @@ def build_model_config_sweep( `setup_fn`. Defaults to `(x, layer) -> layer(x)`. probe_provider: Kernel provider used during memory probing. extra_configs: Optional static overrides merged into the benchmark config. + token_length: Optional token length used for memory probing and sweep config. bt: Target total tokens (batch_size * seq_len) used to derive sweep. + probe_x: Value of x passed to setup_fn during probing. This should be + specified if the kernel's input.x is not T. overwrite: Whether to overwrite existing benchmark results. Returns: A dictionary consumable by `run_benchmarks`. """ - def probe_fn_factory(model_cfg, probe_seq_len): + def probe_fn_factory(model_cfg): def _probe(): probe_input = SingleBenchmarkRunInput( - x=probe_seq_len, + x=probe_x if probe_x is not None else bt, kernel_provider=probe_provider, extra_benchmark_config=build_extra_config( model_cfg, @@ -492,13 +485,16 @@ def _probe(): def build_token_length_sweep( kernel_name: str, - probe_seq_len: int, + probe_x: int, model: ModelConfig, setup_fn: Callable[[SingleBenchmarkRunInput], Tuple[Any, ...]], model_keys: List[str], extra_configs: Optional[Dict] = None, - forward_fn: Callable[..., torch.Tensor] = _default_forward_fn, - probe_provider: str = "torch", + forward_fn: Callable[..., torch.Tensor] = default_forward_fn, + probe_provider: str = "huggingface", + probe_bt: Optional[int] = 1024, + x_name: str = "T", + x_label: str = "sequence length", x_values_fn: Optional[Callable[[SeqLenSweepConfig], List[int]]] = None, overwrite: bool = False, ) -> Dict: @@ -506,7 +502,7 @@ def build_token_length_sweep( Args: kernel_name: Name of the kernel being benchmarked. - probe_seq_len: Sequence length used for memory probing. + probe_x: Value of x passed to setup_fn during probing. model: Model configuration used for the sweep. setup_fn: Function that prepares inputs and modules given a `SingleBenchmarkRunInput`. Returns a tuple of objects consumed @@ -517,8 +513,11 @@ def build_token_length_sweep( forward_fn: Function that executes the kernel given the outputs of `setup_fn`. Defaults to `(x, layer) -> layer(x)`. probe_provider: Kernel provider used during memory probing. + probe_bt: Target total tokens (batch_size * seq_len) used to derive sweep. + x_name: Name of the x-axis variable (e.g. "T" or "B"). + x_label: Label for the x-axis (e.g. "sequence length" or "batch size"). x_values_fn: Optional function mapping `SeqLenSweepConfig` to a list - of sequence lengths. Defaults to powers of 2 up to max seq_len. + of x values. Defaults to powers of 2 up to max seq_len. overwrite: Whether to overwrite existing benchmark results. Returns: @@ -528,7 +527,7 @@ def build_token_length_sweep( def probe_fn(): probe_input = SingleBenchmarkRunInput( - x=probe_seq_len, + x=probe_x, kernel_provider=probe_provider, extra_benchmark_config=build_extra_config( model, @@ -540,17 +539,16 @@ def probe_fn(): return forward_fn(*setup_out) peak_bytes = estimate_kernel_peak_memory(probe_fn=probe_fn) - kernel_bpt = peak_bytes // probe_seq_len + kernel_bpt = max(1, peak_bytes // max(probe_x, probe_bt)) config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) - if x_values_fn is None: - x_values_fn = lambda cfg: [2**i for i in range(10, int(math.log2(cfg.seq_len)) + 1)] + x_values_fn = lambda cfg: [2**i for i in range(10, int(math.log2(cfg.seq_len * cfg.batch_size)) + 1)] return { "kernel_name": kernel_name, - "x_name": "T", - "x_label": "sequence length", + "x_name": x_name, + "x_label": x_label, "x_values": x_values_fn(config), "extra_benchmark_configs": [ build_extra_config( diff --git a/benchmark/scripts/benchmark_swiglu.py b/benchmark/scripts/benchmark_swiglu.py index fb95bb167..fc0f42979 100644 --- a/benchmark/scripts/benchmark_swiglu.py +++ b/benchmark/scripts/benchmark_swiglu.py @@ -7,11 +7,10 @@ from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaMLP from utils import SingleBenchmarkRunInput -from utils import SingleBenchmarkRunOutput +from utils import build_memory_bench_fn +from utils import build_speed_bench_fn from utils import parse_benchmark_script_args from utils import run_benchmarks -from utils import run_memory_benchmark -from utils import run_speed_benchmark from liger_kernel.transformers.swiglu import LigerSwiGLUMLP from liger_kernel.utils import infer_device @@ -19,7 +18,7 @@ device = infer_device() -def _setup_swiglu(input: SingleBenchmarkRunInput): +def setup_swiglu(input: SingleBenchmarkRunInput): """Create input tensor and SwiGLU layer from benchmark config.""" cfg = input.extra_benchmark_config if isinstance(input.x, str): @@ -56,16 +55,6 @@ def _setup_swiglu(input: SingleBenchmarkRunInput): return x, layer -def bench_speed_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _setup_swiglu(input) - return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x]) - - -def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - x, layer = _setup_swiglu(input) - return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode) - - if __name__ == "__main__": args = parse_benchmark_script_args() @@ -75,7 +64,7 @@ def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOut common_configs = build_model_config_sweep( kernel_name="swiglu", all_model_configs=all_model_configs, - setup_fn=_setup_swiglu, + setup_fn=setup_swiglu, model_keys=["hidden_size", "intermediate_size", "dtype"], probe_provider="huggingface", extra_configs={ @@ -91,9 +80,9 @@ def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOut common_configs = build_token_length_sweep( kernel_name="swiglu", - probe_seq_len=probe_seq_len, + probe_x=probe_seq_len, model=model, - setup_fn=_setup_swiglu, + setup_fn=setup_swiglu, model_keys=["hidden_size", "intermediate_size", "dtype"], extra_configs={"hidden_act": "silu", "bsz": 1}, probe_provider="huggingface", @@ -103,14 +92,14 @@ def bench_memory_swiglu(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOut common_configs["kernel_providers"] = ["liger", "huggingface"] run_benchmarks( - bench_test_fn=bench_speed_swiglu, + bench_test_fn=build_speed_bench_fn(setup_swiglu), kernel_operation_modes=["full", "forward", "backward"], metric_name="speed", metric_unit="ms", **common_configs, ) run_benchmarks( - bench_test_fn=bench_memory_swiglu, + bench_test_fn=build_memory_bench_fn(setup_swiglu), kernel_operation_modes=["full", "forward", "backward"], metric_name="memory", metric_unit="MB", diff --git a/benchmark/scripts/utils.py b/benchmark/scripts/utils.py index 0cb307d19..f3774a133 100644 --- a/benchmark/scripts/utils.py +++ b/benchmark/scripts/utils.py @@ -191,6 +191,33 @@ def full(): return SingleBenchmarkRunOutput(y_20=mem_20, y_50=mem_50, y_80=mem_80) +def default_forward_fn(*setup_out): + x, layer = setup_out[0], setup_out[1] + return layer(x) + + +def build_speed_bench_fn( + setup_fn: Callable[["SingleBenchmarkRunInput"], Any], + forward_fn: Callable[..., torch.Tensor] = default_forward_fn, +) -> Callable: + def bench_speed(input: "SingleBenchmarkRunInput") -> SingleBenchmarkRunOutput: + setup_out = setup_fn(input) + return run_speed_benchmark(lambda: forward_fn(*setup_out), input.kernel_operation_mode, [setup_out[0]]) + + return bench_speed + + +def build_memory_bench_fn( + setup_fn: Callable[["SingleBenchmarkRunInput"], Any], + forward_fn: Callable[..., torch.Tensor] = default_forward_fn, +) -> Callable: + def bench_memory(input: "SingleBenchmarkRunInput") -> SingleBenchmarkRunOutput: + setup_out = setup_fn(input) + return run_memory_benchmark(lambda: forward_fn(*setup_out), input.kernel_operation_mode) + + return bench_memory + + def get_current_file_directory() -> str: """ Returns the directory path of the current Python file. From f7bf977f9d379d3ef06d6469bb8a195d90848496 Mon Sep 17 00:00:00 2001 From: lowdy1 Date: Wed, 29 Apr 2026 06:33:53 +0000 Subject: [PATCH 4/4] add probe_dim and scale_dim --- benchmark/README.md | 16 ++++++-- benchmark/scripts/benchmark_cpo_loss.py | 5 +-- benchmark/scripts/benchmark_layer_norm.py | 2 + benchmark/scripts/benchmark_model_configs.py | 41 +++++++++++++++----- benchmark/scripts/benchmark_swiglu.py | 2 + 5 files changed, 51 insertions(+), 15 deletions(-) diff --git a/benchmark/README.md b/benchmark/README.md index fac77ccac..9597d94e3 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -41,8 +41,10 @@ The benchmarking system is designed to provide a **consistent, low-boilerplate w (a) `build_model_config_sweep` - * Sweeps across **model configurations** + * Sweeps across **model configurations**(e.g. hidden size, dtype, vocab size) * Keeps total tokens (`B * T`) approximately constant + * Automatically derives a suitable `(B, T)` that will not cause OOM under the given token budget + * `probe_dim` must align with how `input.x` is interpreted in `setup_fn` ```python common_configs = build_model_config_sweep( @@ -50,13 +52,20 @@ The benchmarking system is designed to provide a **consistent, low-boilerplate w all_model_configs=..., setup_fn=..., model_keys=[...], + probe_dim: Literal["T", "B", "BT"] = "T" ) ``` (b) `build_token_length_sweep` - * Sweeps across **sequence length (T)** - * Keep one static model + * Sweeps along a **chosen scaling dimension**: + + * `"T"` → sequence length + * `"B"` → batch size + * `"BT"` → total tokens + * Uses a **single fixed model configuration** + * Maintains a consistent memory model via bytes-per-token estimation + * `scale_dim` must align with how `input.x` is interpreted in `setup_fn` ```python common_configs = build_token_length_sweep( @@ -65,6 +74,7 @@ The benchmarking system is designed to provide a **consistent, low-boilerplate w model=..., setup_fn=..., model_keys=[...], + scale_dim: Literal["T", "B", "BT"] = "T", ) ``` diff --git a/benchmark/scripts/benchmark_cpo_loss.py b/benchmark/scripts/benchmark_cpo_loss.py index 31fb8f734..d2ecaf43f 100644 --- a/benchmark/scripts/benchmark_cpo_loss.py +++ b/benchmark/scripts/benchmark_cpo_loss.py @@ -63,11 +63,11 @@ def setup_cpo_loss(input: SingleBenchmarkRunInput): common_configs = build_model_config_sweep( kernel_name="cpo_loss", - probe_x=1, # batch_size all_model_configs=all_model_configs, setup_fn=setup_cpo_loss, model_keys=["hidden_size", "vocab_size", "dtype"], extra_configs={"T": T}, + probe_dim="B", probe_provider="huggingface", bt=args.bt, overwrite=args.overwrite, @@ -82,12 +82,11 @@ def setup_cpo_loss(input: SingleBenchmarkRunInput): setup_fn=setup_cpo_loss, model_keys=["hidden_size", "vocab_size", "dtype"], extra_configs={"T": T}, + scale_dim="B", probe_provider="huggingface", x_values_fn=lambda config: [ 2**i for i in range(1, int(math.log2(max(2, config.batch_size * config.seq_len // T))) + 1) ], - x_name="B", # default x is seq_len, but for CPO loss we want to sweep batch size instead - x_label="Batch Size", overwrite=args.overwrite, ) diff --git a/benchmark/scripts/benchmark_layer_norm.py b/benchmark/scripts/benchmark_layer_norm.py index 77456de32..ce6032fd0 100644 --- a/benchmark/scripts/benchmark_layer_norm.py +++ b/benchmark/scripts/benchmark_layer_norm.py @@ -60,6 +60,7 @@ def setup_layer_norm(input: SingleBenchmarkRunInput): extra_configs={ "eps": 1e-6, }, + probe_dim="BT", probe_provider="huggingface", bt=args.bt, overwrite=args.overwrite, @@ -78,6 +79,7 @@ def setup_layer_norm(input: SingleBenchmarkRunInput): extra_configs={ "eps": 1e-6, }, + scale_dim="BT", probe_provider="huggingface", overwrite=args.overwrite, ) diff --git a/benchmark/scripts/benchmark_model_configs.py b/benchmark/scripts/benchmark_model_configs.py index 52a5387cf..4b56f592f 100644 --- a/benchmark/scripts/benchmark_model_configs.py +++ b/benchmark/scripts/benchmark_model_configs.py @@ -30,6 +30,7 @@ from typing import Callable from typing import Dict from typing import List +from typing import Literal from typing import Optional from typing import Tuple @@ -414,11 +415,11 @@ def build_model_config_sweep( all_model_configs: List[ModelConfig], setup_fn: Callable[[SingleBenchmarkRunInput], Tuple[Any, ...]], model_keys: List[str], + probe_dim: Literal["T", "B", "BT"] = "T", forward_fn: Callable[..., torch.Tensor] = default_forward_fn, probe_provider: str = "huggingface", extra_configs: Optional[Dict] = None, bt: int = 2048, - probe_x: int = None, overwrite: bool = False, ) -> Dict: """Build benchmark config dict for model-config sweep. @@ -448,7 +449,7 @@ def build_model_config_sweep( def probe_fn_factory(model_cfg): def _probe(): probe_input = SingleBenchmarkRunInput( - x=probe_x if probe_x is not None else bt, + x=1 if probe_dim == "B" else bt, kernel_provider=probe_provider, extra_benchmark_config=build_extra_config( model_cfg, @@ -490,10 +491,9 @@ def build_token_length_sweep( setup_fn: Callable[[SingleBenchmarkRunInput], Tuple[Any, ...]], model_keys: List[str], extra_configs: Optional[Dict] = None, + scale_dim: Literal["T", "B", "BT"] = "T", forward_fn: Callable[..., torch.Tensor] = default_forward_fn, probe_provider: str = "huggingface", - probe_bt: Optional[int] = 1024, - x_name: str = "T", x_label: str = "sequence length", x_values_fn: Optional[Callable[[SeqLenSweepConfig], List[int]]] = None, overwrite: bool = False, @@ -513,8 +513,7 @@ def build_token_length_sweep( forward_fn: Function that executes the kernel given the outputs of `setup_fn`. Defaults to `(x, layer) -> layer(x)`. probe_provider: Kernel provider used during memory probing. - probe_bt: Target total tokens (batch_size * seq_len) used to derive sweep. - x_name: Name of the x-axis variable (e.g. "T" or "B"). + scale_dim: Dimension along which to scale the sweep (e.g. "T", "B", or "BT"). x_label: Label for the x-axis (e.g. "sequence length" or "batch size"). x_values_fn: Optional function mapping `SeqLenSweepConfig` to a list of x values. Defaults to powers of 2 up to max seq_len. @@ -539,15 +538,39 @@ def probe_fn(): return forward_fn(*setup_out) peak_bytes = estimate_kernel_peak_memory(probe_fn=probe_fn) - kernel_bpt = max(1, peak_bytes // max(probe_x, probe_bt)) + # --------------------------------------- + # derive total tokens (BT) based on scale_dim + # --------------------------------------- + if scale_dim == "T": + B = extra_configs.get("B", 1) + probe_bt = probe_x * B + + elif scale_dim == "B": + T = extra_configs.get("T") + if T is None: + raise ValueError("For B sweep, extra_configs['T'] must be provided") + probe_bt = probe_x * T + + elif scale_dim == "BT": + probe_bt = probe_x + + else: + raise ValueError(f"Unsupported scale_dim: {scale_dim}") + + kernel_bpt = max(1, peak_bytes // probe_bt) config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt) if x_values_fn is None: - x_values_fn = lambda cfg: [2**i for i in range(10, int(math.log2(cfg.seq_len * cfg.batch_size)) + 1)] + if scale_dim == "T": + x_values_fn = lambda cfg: [2**i for i in range(10, int(math.log2(cfg.seq_len)) + 1)] + elif scale_dim == "B": + x_values_fn = lambda cfg: [2**i for i in range(0, int(math.log2(cfg.batch_size)) + 1)] + elif scale_dim == "BT": + x_values_fn = lambda cfg: [2**i for i in range(10, int(math.log2(cfg.seq_len * cfg.batch_size)) + 1)] return { "kernel_name": kernel_name, - "x_name": x_name, + "x_name": scale_dim, "x_label": x_label, "x_values": x_values_fn(config), "extra_benchmark_configs": [ diff --git a/benchmark/scripts/benchmark_swiglu.py b/benchmark/scripts/benchmark_swiglu.py index fc0f42979..200bbf405 100644 --- a/benchmark/scripts/benchmark_swiglu.py +++ b/benchmark/scripts/benchmark_swiglu.py @@ -71,6 +71,7 @@ def setup_swiglu(input: SingleBenchmarkRunInput): "bsz": 1, "hidden_act": "silu", }, + probe_dim="BT", bt=args.bt, overwrite=args.overwrite, ) @@ -85,6 +86,7 @@ def setup_swiglu(input: SingleBenchmarkRunInput): setup_fn=setup_swiglu, model_keys=["hidden_size", "intermediate_size", "dtype"], extra_configs={"hidden_act": "silu", "bsz": 1}, + scale_dim="BT", probe_provider="huggingface", overwrite=args.overwrite, )