|
| 1 | +import os |
| 2 | +import sys |
| 3 | + |
| 4 | +import torch |
| 5 | +import triton |
| 6 | + |
| 7 | +from utils import QUANTILES |
| 8 | +from utils import SingleBenchmarkRunInput |
| 9 | +from utils import SingleBenchmarkRunOutput |
| 10 | +from utils import _test_memory |
| 11 | +from utils import parse_benchmark_script_args |
| 12 | +from utils import run_benchmarks |
| 13 | + |
| 14 | +from liger_kernel.utils import infer_device |
| 15 | + |
| 16 | +device = infer_device() |
| 17 | + |
| 18 | +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) |
| 19 | + |
| 20 | + |
| 21 | +def bench_speed_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: |
| 22 | + from test.transformers.test_dyt import LigerDyT |
| 23 | + from test.transformers.test_dyt import TorchDyT |
| 24 | + |
| 25 | + hidden_size = input.x |
| 26 | + provider = input.kernel_provider |
| 27 | + mode = input.kernel_operation_mode |
| 28 | + extra_benchmark_config = input.extra_benchmark_config |
| 29 | + BT = extra_benchmark_config["BT"] |
| 30 | + dtype = extra_benchmark_config["dtype"] |
| 31 | + |
| 32 | + x_shape = (BT, hidden_size) |
| 33 | + torch_y = lambda x: TorchDyT(hidden_size=hidden_size).to(device)(x) |
| 34 | + torch_compile_y = lambda x: torch.compile( |
| 35 | + TorchDyT(hidden_size=hidden_size).to(device) |
| 36 | + )(x) |
| 37 | + triton_y = lambda x: LigerDyT(hidden_size=hidden_size).to(device)(x) |
| 38 | + |
| 39 | + x = torch.randn(x_shape, dtype=dtype, device=device) |
| 40 | + dy = torch.randn_like(x) |
| 41 | + x.requires_grad_(True) |
| 42 | + |
| 43 | + def fwd(): |
| 44 | + if provider == "liger": |
| 45 | + return triton_y(x) |
| 46 | + elif provider == "torch": |
| 47 | + return torch_y(x) |
| 48 | + elif provider == "torch_compile": |
| 49 | + return torch_compile_y(x) |
| 50 | + |
| 51 | + if mode == "forward": |
| 52 | + ms_50, ms_20, ms_80 = triton.testing.do_bench( |
| 53 | + fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500 |
| 54 | + ) |
| 55 | + elif mode == "backward": |
| 56 | + y = fwd() |
| 57 | + ms_50, ms_20, ms_80 = triton.testing.do_bench( |
| 58 | + lambda: y.backward(dy, retain_graph=True), |
| 59 | + quantiles=QUANTILES, |
| 60 | + grad_to_none=[x], |
| 61 | + rep=500, |
| 62 | + ) |
| 63 | + elif mode == "full": |
| 64 | + |
| 65 | + def full(): |
| 66 | + y = fwd() |
| 67 | + y.backward(dy) |
| 68 | + |
| 69 | + ms_50, ms_20, ms_80 = triton.testing.do_bench( |
| 70 | + full, quantiles=QUANTILES, grad_to_none=[x], rep=500 |
| 71 | + ) |
| 72 | + |
| 73 | + return SingleBenchmarkRunOutput( |
| 74 | + y_20=ms_20, |
| 75 | + y_50=ms_50, |
| 76 | + y_80=ms_80, |
| 77 | + ) |
| 78 | + |
| 79 | + |
| 80 | +def bench_memory_dyt(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: |
| 81 | + from test.transformers.test_dyt import LigerDyT |
| 82 | + from test.transformers.test_dyt import TorchDyT |
| 83 | + |
| 84 | + hidden_size = input.x |
| 85 | + provider = input.kernel_provider |
| 86 | + extra_benchmark_config = input.extra_benchmark_config |
| 87 | + BT = extra_benchmark_config["BT"] |
| 88 | + dtype = extra_benchmark_config["dtype"] |
| 89 | + |
| 90 | + x_shape = (BT, hidden_size) |
| 91 | + torch_y = lambda x: TorchDyT(hidden_size=hidden_size).to(device)(x) |
| 92 | + torch_compile_y = lambda x: torch.compile( |
| 93 | + TorchDyT(hidden_size=hidden_size).to(device) |
| 94 | + )(x) |
| 95 | + triton_y = lambda x: LigerDyT(hidden_size=hidden_size).to(device)(x) |
| 96 | + |
| 97 | + x = torch.randn(x_shape, dtype=dtype, device=device) |
| 98 | + dy = torch.randn_like(x) |
| 99 | + x.requires_grad_(True) |
| 100 | + |
| 101 | + def fwd(): |
| 102 | + if provider == "liger": |
| 103 | + return triton_y(x) |
| 104 | + elif provider == "torch": |
| 105 | + return torch_y(x) |
| 106 | + elif provider == "torch_compile": |
| 107 | + return torch_compile_y(x) |
| 108 | + |
| 109 | + def full(): |
| 110 | + y = fwd() |
| 111 | + y.backward(dy, retain_graph=True) |
| 112 | + |
| 113 | + mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES) |
| 114 | + return SingleBenchmarkRunOutput( |
| 115 | + y_20=mem_20, |
| 116 | + y_50=mem_50, |
| 117 | + y_80=mem_80, |
| 118 | + ) |
| 119 | + |
| 120 | + |
| 121 | +if __name__ == "__main__": |
| 122 | + args = parse_benchmark_script_args() |
| 123 | + |
| 124 | + common_configs = { |
| 125 | + "kernel_name": "dyt", |
| 126 | + "x_name": "hidden_size", |
| 127 | + "x_label": "hidden size", |
| 128 | + "x_values": [2**i for i in range(10, 15)], |
| 129 | + "kernel_providers": ["liger", "torch", "torch_compile"], |
| 130 | + "extra_benchmark_configs": [{"BT": 4096, "dtype": torch.float32}], |
| 131 | + "overwrite": args.overwrite, |
| 132 | + } |
| 133 | + |
| 134 | + run_benchmarks( |
| 135 | + bench_test_fn=bench_speed_dyt, |
| 136 | + kernel_operation_modes=["forward", "backward", "full"], |
| 137 | + metric_name="speed", |
| 138 | + metric_unit="ms", |
| 139 | + **common_configs, |
| 140 | + ) |
| 141 | + run_benchmarks( |
| 142 | + bench_test_fn=bench_memory_dyt, |
| 143 | + kernel_operation_modes=["full"], |
| 144 | + metric_name="memory", |
| 145 | + metric_unit="MB", |
| 146 | + **common_configs, |
| 147 | + ) |
0 commit comments