193 changes: 70 additions & 123 deletions benchmark/scripts/benchmark_layer_norm.py
@@ -3,9 +3,8 @@
import torch

from benchmark_model_configs import MODEL_REGISTRY
from benchmark_model_configs import compute_model_config_sweep_config
from benchmark_model_configs import compute_seq_len_sweep_config
from benchmark_model_configs import estimate_kernel_peak_memory
from benchmark_model_configs import build_model_config_sweep
from benchmark_model_configs import build_token_length_sweep
from benchmark_model_configs import get_benchmark_model_config
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
@@ -23,19 +22,28 @@
def _setup_layer_norm(input: SingleBenchmarkRunInput):
"""Create input tensor and LayerNorm layer from benchmark config."""
cfg = input.extra_benchmark_config
hidden_size = cfg["hidden_size"]
if isinstance(input.x, str):
model_cfg = MODEL_REGISTRY[input.x]
seq_len = cfg["seq_len"]
hidden_size = model_cfg.hidden_size
dtype = model_cfg.dtype
else:
seq_len = input.x
hidden_size = cfg["hidden_size"]
dtype = cfg["dtype"]

eps = cfg["eps"]
x = torch.randn(
input.x,
seq_len,
hidden_size,
device=device,
dtype=cfg["dtype"],
dtype=dtype,
requires_grad=True,
)
if input.kernel_provider == "liger":
layer = LigerLayerNorm(hidden_size=hidden_size, eps=eps).to(device)
layer = LigerLayerNorm(hidden_size=hidden_size, eps=eps).to(device).to(dtype)
elif input.kernel_provider == "huggingface":
layer = torch.nn.LayerNorm(hidden_size, eps=eps).to(device)
layer = torch.nn.LayerNorm(hidden_size, eps=eps).to(device).to(dtype)
else:
raise ValueError(f"Invalid provider: {input.kernel_provider} for LayerNorm")
return x, layer
@@ -51,100 +59,41 @@ def bench_memory_layer_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRu
return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode)


def _resolve_model_config_layer_norm(input: SingleBenchmarkRunInput):
"""Resolve model-config-sweep input into standard setup args."""
cfg = input.extra_benchmark_config
model_info = cfg["model_configs"][input.x]
return _setup_layer_norm(
SingleBenchmarkRunInput(
x=cfg["BT"],
kernel_provider=input.kernel_provider,
extra_benchmark_config={
"hidden_size": model_info["hidden_size"],
"dtype": model_info["dtype"],
"eps": cfg["eps"],
},
)
)


def bench_speed_layer_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
x, layer = _resolve_model_config_layer_norm(input)
return run_speed_benchmark(lambda: layer(x), input.kernel_operation_mode, [x])


def bench_memory_layer_norm_model_config(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
x, layer = _resolve_model_config_layer_norm(input)
return run_memory_benchmark(lambda: layer(x), input.kernel_operation_mode)


if __name__ == "__main__":
args = parse_benchmark_script_args()

if args.sweep_mode == "model_config":
all_model_configs = list(MODEL_REGISTRY.values())

def _probe_factory(model_cfg, probe_bt):
def _probe():
probe_input = SingleBenchmarkRunInput(
x=probe_bt,
kernel_provider="huggingface",
extra_benchmark_config={
"hidden_size": model_cfg.hidden_size,
"dtype": model_cfg.dtype,
"eps": 1e-6,
},
)
x, layer = _setup_layer_norm(probe_input)
return layer(x)

return _probe

sweep = compute_model_config_sweep_config(all_model_configs, probe_fn_factory=_probe_factory, bt=args.bt)

model_configs_info = {
cfg.name: {
"hidden_size": cfg.hidden_size,
"dtype": cfg.dtype,
}
for cfg in sweep.model_configs
}

common_configs = {
"kernel_name": "layer_norm",
"x_name": "model_config",
"x_label": "model configuration",
"x_values": [cfg.name for cfg in sweep.model_configs],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"model_configs": model_configs_info,
"BT": sweep.bt,
def probe_fn(model_cfg, probe_bt):
probe_input = SingleBenchmarkRunInput(
x=probe_bt,
kernel_provider="huggingface",
extra_benchmark_config={
"hidden_size": model_cfg.hidden_size,
"dtype": model_cfg.dtype,
"eps": 1e-6,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_layer_norm_model_config,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_layer_norm_model_config,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs,
},
)
x, layer = _setup_layer_norm(probe_input)
return layer(x)

common_configs = build_model_config_sweep(
kernel_name="layer_norm",
all_model_configs=all_model_configs,
Collaborator: nit: drop this argument as we always pass MODEL_REGISTRY.values()

probe_fn=probe_fn,
extra_benchmark_config={
"eps": 1e-6,
},
bt=args.bt,
overwrite=args.overwrite,
)

else:
model = get_benchmark_model_config(args.model)
probe_bt = 1024

def _probe():
def probe_fn():
probe_input = SingleBenchmarkRunInput(
x=probe_bt,
kernel_provider="huggingface",
@@ -157,38 +106,36 @@ def _probe():
x, layer = _setup_layer_norm(probe_input)
return layer(x)

peak_bytes = estimate_kernel_peak_memory(probe_fn=_probe)
kernel_bpt = peak_bytes // probe_bt
def x_values_fn(config):
return [2**i for i in range(10, int(math.log2(config.seq_len)) + 1)]

common_configs = build_token_length_sweep(
kernel_name="layer_norm",
probe_seq_len=probe_bt,
model=model,
probe_fn=probe_fn,
extra_config_fn={
"hidden_size": model.hidden_size,
"dtype": model.dtype,
"eps": 1e-6,
},
x_values_fn=x_values_fn,
overwrite=args.overwrite,
)

config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)
common_configs["kernel_providers"] = ["liger", "huggingface"]

common_configs = {
"kernel_name": "layer_norm",
"x_name": "BT",
"x_label": "B * T",
"x_values": [2**i for i in range(10, int(math.log2(config.batch_size * config.seq_len)) + 1)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"hidden_size": model.hidden_size,
"dtype": model.dtype,
"eps": 1e-6,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_layer_norm,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_layer_norm,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_speed_layer_norm,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="speed",
metric_unit="ms",
**common_configs,
)
run_benchmarks(
bench_test_fn=bench_memory_layer_norm,
kernel_operation_modes=["full", "forward", "backward"],
metric_name="memory",
metric_unit="MB",
**common_configs,
)
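For context, here is a minimal sketch (not part of the diff) of the two input styles the reworked _setup_layer_norm accepts: an integer token count for the token-length sweep and a registry model name for the model-config sweep. Only the constructor shape and config keys come from the diff above; the concrete values and the "llama_3_8b" registry key are illustrative, and the snippet assumes the imports and device global at the top of benchmark_layer_norm.py.

# Token-length sweep: x is an integer sequence length; dims, dtype, and eps come from the config dict.
token_input = SingleBenchmarkRunInput(
    x=4096,
    kernel_provider="liger",
    extra_benchmark_config={"hidden_size": 4096, "dtype": torch.bfloat16, "eps": 1e-6},
)
x, layer = _setup_layer_norm(token_input)  # x is a (4096, 4096) tensor, layer is a LigerLayerNorm

# Model-config sweep: x is a model name; hidden_size and dtype are looked up in MODEL_REGISTRY,
# while seq_len is shared across models via the config built by build_model_config_sweep.
model_input = SingleBenchmarkRunInput(
    x="llama_3_8b",  # illustrative key; must exist in MODEL_REGISTRY
    kernel_provider="huggingface",
    extra_benchmark_config={"seq_len": 2048, "eps": 1e-6},
)
x, layer = _setup_layer_norm(model_input)  # shape (2048, model hidden_size), dtype from the registry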
87 changes: 86 additions & 1 deletion benchmark/scripts/benchmark_model_configs.py
@@ -341,7 +341,7 @@ def compute_seq_len_sweep_config(
def compute_model_config_sweep_config(
model_configs: List[ModelConfig],
probe_fn_factory: Callable[[ModelConfig, int], Callable[[], torch.Tensor]],
bt: int = 2048,
bt: int = 1024,
memory_utilization: float = 0.4,
) -> ModelConfigSweepConfig:
"""Find safe (batch_size, seq_len) that works across all model configs.
@@ -383,3 +383,88 @@ def compute_model_config_sweep_config(
batch_size=batch_size,
seq_len=seq_len,
)


def build_model_config_sweep(
kernel_name: str,
all_model_configs: List[ModelConfig],
probe_fn: Callable[[ModelConfig, int], torch.Tensor],
extra_benchmark_config: Dict,
bt: int = 2048,
overwrite: bool = False,
) -> Dict:
"""Build benchmark config dict for model-config sweep.

Returns a single extra_benchmark_config (with bsz and seq_len filled in from
the sweep), so run_benchmarks iterates over model names as x_values rather
than repeating per-model configs.

Args:
kernel_name: Name of the kernel being benchmarked.
all_model_configs: List of model configs to sweep over.
probe_fn: Callable(model_cfg, probe_seq_len) -> output tensor for memory estimation.
extra_benchmark_config: Base config dict. bsz and seq_len will be set from the sweep result.
bt: Target total tokens (batch_size * seq_len) used as probe bt.
overwrite: Whether to overwrite existing benchmark data.
"""

def probe_fn_factory(model_cfg, probe_seq_len):
return lambda: probe_fn(model_cfg, probe_seq_len)

sweep = compute_model_config_sweep_config(
all_model_configs,
probe_fn_factory=probe_fn_factory,
bt=bt,
)

config = {**extra_benchmark_config, "bsz": sweep.batch_size, "seq_len": sweep.seq_len}

return {
"kernel_name": kernel_name,
"x_name": "model_config",
"x_label": "model configuration",
"x_values": [cfg.name for cfg in sweep.model_configs],
"extra_benchmark_configs": [config],
"overwrite": overwrite,
}


def build_token_length_sweep(
kernel_name: str,
probe_seq_len: int,
model: ModelConfig,
probe_fn: Callable[[], torch.Tensor],
extra_config_fn: Callable[[SeqLenSweepConfig], Dict] | Dict,
x_values_fn: Callable[[SeqLenSweepConfig], List[int]],
overwrite: bool = False,
) -> Dict:
"""Build benchmark config dict for token-length sweep.

Args:
kernel_name: Name of the kernel being benchmarked.
probe_seq_len: Sequence length used for the memory-estimation probe.
model: Model config to use for the sweep.
probe_fn: Callable() -> output tensor for memory estimation.
extra_config_fn: Either a plain dict, or a Callable(config) -> dict, with the
normalized keys that _setup_* expects.
x_values_fn: Callable(config) -> list of sequence lengths to benchmark.
overwrite: Whether to overwrite existing benchmark data.

Returns:
Dict with keys: kernel_name, x_name, x_label, x_values, kernel_providers,
extra_benchmark_configs, overwrite.
"""
peak_bytes = estimate_kernel_peak_memory(probe_fn=probe_fn)
kernel_bpt = peak_bytes // probe_seq_len

config = compute_seq_len_sweep_config(model, kernel_bytes_per_token=kernel_bpt)

return {
"kernel_name": kernel_name,
"x_name": "T",
"x_label": "sequence length",
"x_values": x_values_fn(config),
"extra_benchmark_configs": [extra_config_fn]
if isinstance(extra_config_fn, dict)
else [extra_config_fn(config)],
"overwrite": overwrite,
}
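
To make the intended call pattern concrete, here is a hedged sketch of how another kernel script might reuse build_token_length_sweep. The builder signature and the returned keys are taken from this diff; the kernel name, probe body, and bench_speed_my_kernel are hypothetical, and the snippet assumes the same imports and globals as benchmark_layer_norm.py (math, torch, device, run_benchmarks, parsed args).

# Hypothetical caller mirroring benchmark_layer_norm.py for some other kernel.
model = get_benchmark_model_config(args.model)
probe_seq_len = 1024

def probe_fn():
    # One small forward pass so the builder can estimate peak bytes per token.
    x = torch.randn(probe_seq_len, model.hidden_size, device=device, dtype=model.dtype, requires_grad=True)
    layer = torch.nn.LayerNorm(model.hidden_size, eps=1e-6).to(device).to(model.dtype)
    return layer(x)

common_configs = build_token_length_sweep(
    kernel_name="my_kernel",  # hypothetical
    probe_seq_len=probe_seq_len,
    model=model,
    probe_fn=probe_fn,
    extra_config_fn={"hidden_size": model.hidden_size, "dtype": model.dtype, "eps": 1e-6},
    x_values_fn=lambda cfg: [2**i for i in range(10, int(math.log2(cfg.seq_len)) + 1)],
    overwrite=args.overwrite,
)
common_configs["kernel_providers"] = ["liger", "huggingface"]  # the builder does not set providers

run_benchmarks(
    bench_test_fn=bench_speed_my_kernel,  # hypothetical speed benchmark for the kernel
    kernel_operation_modes=["full", "forward", "backward"],
    metric_name="speed",
    metric_unit="ms",
    **common_configs,
)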