
Commit 8369268

Generate speedup for inference (#2151)
1 parent e5d9a97

File tree

5 files changed: +73 additions, −19 deletions

- benchmarks/microbenchmarks/README.md
- benchmarks/microbenchmarks/benchmark_inference.py
- benchmarks/microbenchmarks/benchmark_runner.py
- benchmarks/microbenchmarks/test/benchmark_config.yml
- benchmarks/microbenchmarks/utils.py


benchmarks/microbenchmarks/README.md

Lines changed: 14 additions & 0 deletions
````diff
@@ -130,6 +130,18 @@ Currently, quantization string is in same format as the one being passed in llam
       max_power: 11
   ```
 
+- `small_sweep`: Generate a small sweep of shapes with increasing powers of 2 for M, K, N dimensions
+  - Parameters:
+    - `min_power`: Minimum power of 2 (default: 10, which is 1024)
+    - `max_power`: Maximum power of 2 (default: 14, which is 16,384)
+  - Note: This generates shapes where M <= K <= N (ensuring increasing order), which produces fewer combinations than the full sweep, and could be good to use for plots like heatmap
+  ```yaml
+  matrix_shapes:
+    - name: "small_sweep"
+      min_power: 10 # 2^10 = 1024
+      max_power: 15 # 2^15 = 32,768
+  ```
+
 - `sweep`: Generate a sweep of shapes with different powers of 2 for M, K, N dimensions
   - Parameters:
     - `min_power`: Minimum power of 2 (default: 8, which is 256)
@@ -142,6 +154,8 @@ Currently, quantization string is in same format as the one being passed in llam
       max_power: 9 # 2^9 = 512
   ```
 
+
+
 ## Output
 
 Results are saved to a CSV file in the specified output directory
````
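For intuition, here is a minimal standalone sketch (illustrative only, not the repo's implementation; the helper name `enumerate_small_sweep` is made up) of the shape set a `small_sweep` entry describes: only (M, K, N) triples with M <= K <= N are kept.

```python
# Sketch (not repo code): enumerate the (M, K, N) shapes a small_sweep entry
# expands to, keeping only combinations with M <= K <= N.
from itertools import combinations_with_replacement


def enumerate_small_sweep(min_power: int = 10, max_power: int = 14):
    powers = [2**p for p in range(min_power, max_power + 1)]
    # combinations_with_replacement over sorted sizes yields exactly the M <= K <= N triples
    return [list(shape) for shape in combinations_with_replacement(powers, 3)]


if __name__ == "__main__":
    print(enumerate_small_sweep(min_power=10, max_power=11))
    # [[1024, 1024, 1024], [1024, 1024, 2048], [1024, 2048, 2048], [2048, 2048, 2048]]
```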

benchmarks/microbenchmarks/benchmark_inference.py

Lines changed: 29 additions & 5 deletions
```diff
@@ -51,9 +51,28 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
         high_precision_dtype=config.high_precision_dtype,
         device=config.device,
     )
+    # Copy base model for quantizing
+    m_copy = deepcopy(base_model)
+
+    # Run benchmarks
+    result = BenchmarkResult(config=config)
+
+    # Store result in model for memory profiling
+    base_model._benchmark_result = result
+
+    # Run baseline benchmarking
+    base_model = base_model.eval().to(config.device)
+    if config.use_torch_compile:
+        print("Compiling baseline model....")
+        base_model = torch.compile(
+            base_model, mode=config.torch_compile_mode, fullgraph=True
+        )
+    # Benchmark time to run an inference call for baseline model
+    print("Benchmarking baseline inference.....")
+    result.baseline_inference_time_in_ms = model_inference_time_in_ms(
+        model=base_model, input_data=input_data
+    )
 
-    # Use quantize_ to apply each quantization function to the model
-    m_copy = deepcopy(base_model).eval().to(config.device)
     ao_base_config = string_to_config(
         config.quantization,
         config.sparsity,
@@ -79,24 +98,29 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
         pass  # No quantization or sparsity specified, do nothing
     else:
         print("Quantizing model....")
+        m_copy = m_copy.eval().to(config.device)
         quantize_(m_copy, ao_base_config)
 
     if config.use_torch_compile:
-        print("Compiling model....")
+        print("Compiling quantized model....")
         m_copy = torch.compile(
             m_copy, mode=config.torch_compile_mode, fullgraph=True
         )
 
-    # Run benchmarks
-    result = BenchmarkResult(config=config)
     # Store result in model for memory profiling
     m_copy._benchmark_result = result
 
     # Benchmark time to run an inference call for quantized model
+    print("Benchmarking quantized model.....")
     result.model_inference_time_in_ms = model_inference_time_in_ms(
         model=m_copy, input_data=input_data
     )
 
+    # Calculate speedup w.r.t. baseline
+    result.speedup = round(
+        result.baseline_inference_time_in_ms / result.model_inference_time_in_ms, 2
+    )
+
     # Run profiler if enabled
     if config.enable_profiler:
         print("Running profiler...")
```

benchmarks/microbenchmarks/benchmark_runner.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -76,6 +76,20 @@ def get_shapes_for_config(
             val2 = 2**power_of_2 + 2 ** (power_of_2 - 1)
             shapes.append((f"{name}_{idx * 2}", [val1, val1, val1]))
             shapes.append((f"{name}_{idx * 2 + 1}", [val2, val2, val2]))
+    elif name == "small_sweep":
+        # Generate a small sweep of shapes with increasing powers of 2 for M, K, N
+        min_p2 = shape_config.get("min_power", 10)  # 1024
+        max_p2 = shape_config.get("max_power", 14)  # 16,384
+        counter = 0
+        for M_p2 in range(min_p2, max_p2 + 1):
+            M = 2**M_p2
+            for K_p2 in range(min_p2, max_p2 + 1):
+                K = 2**K_p2
+                for N_p2 in range(min_p2, max_p2 + 1):
+                    N = 2**N_p2
+                    if M <= K <= N:  # Ensure increasing order
+                        shapes.append((f"{name}_{counter}", [M, K, N]))
+                        counter += 1
     elif name == "sweep":
         # Generate a sweep of shapes with different powers of 2 for M, K, N
         min_p2 = shape_config.get("min_power", 8)  # 256
@@ -202,7 +216,7 @@ def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None
         print("----------------------------------------")
         try:
             print(
-                f"Running: {config.name} for Quantization: {config.quantization} and Sparsity: {config.sparsity}"
+                f"Running: {config.name} for Quantization: {config.quantization} and Sparsity: {config.sparsity} for {config.shape_name}: {config.m, config.k, config.n}"
             )
             result = run_inference(config)  # Pass the config object directly
             if result is not None:  # Only add successful results
```
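Because of the `M <= K <= N` filter, `small_sweep` effectively enumerates combinations with replacement of the power-of-2 sizes rather than the full M × K × N cross product. A quick standalone sketch of the resulting shape counts (the helper is illustrative, not repo code):

```python
# Sketch: count how many shapes small_sweep yields versus the unfiltered
# cross product over the same power-of-2 range.
import math


def small_sweep_count(min_power: int, max_power: int) -> int:
    n = max_power - min_power + 1  # number of distinct sizes in the range
    return math.comb(n + 2, 3)     # multisets of size 3 == triples with M <= K <= N


for lo, hi in [(10, 14), (14, 16)]:
    n = hi - lo + 1
    print(lo, hi, small_sweep_count(lo, hi), "vs", n**3)
# prints: 10 14 35 vs 125
#         14 16 10 vs 27
```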

benchmarks/microbenchmarks/test/benchmark_config.yml

Lines changed: 4 additions & 10 deletions
```diff
@@ -3,18 +3,15 @@ benchmark_mode: "inference"
 quantization_config_recipe_names: # Will run a baseline inference for model by default, without quantization for comparison
   - "int8wo"
   - "int8dq"
-  - "float8dq"
+  - "float8dq-tensor"
   - "float8wo"
 output_dir: "benchmarks/microbenchmarks/results"
 model_params:
   - name: "small_bf16_linear"
     matrix_shapes:
-      - name: "custom"
-        shapes: [
-          [1024, 1024, 1024], # [m, k, n]
-          [2048, 4096, 1024],
-          [4096, 4096, 1024]
-        ]
+      - name: "small_sweep"
+        min_power: 14
+        max_power: 16
     high_precision_dtype: "torch.bfloat16"
     use_torch_compile: true
     torch_compile_mode: "max-autotune"
@@ -60,9 +57,6 @@ model_params:
       - name: "pow2_extended" # Example of using extended power of 2 shapes
         min_power: 10 # 1024
        max_power: 11 # 2048
-      - name: "sweep" # Example of using sweep shapes (commented out as it generates many shapes)
-        min_power: 8 # 256
-        max_power: 9 # 512
     high_precision_dtype: "torch.bfloat16"
     use_torch_compile: true
     torch_compile_mode: "max-autotune"
```
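A small sketch of inspecting this config (assuming PyYAML is available; the expansion is re-derived here for illustration, not taken from the runner): the updated `small_sweep` entry with `min_power: 14` and `max_power: 16` covers the sizes 16,384, 32,768, and 65,536.

```python
# Sketch: read the benchmark config and list the sizes its small_sweep entry covers.
# Assumes PyYAML is installed; the expansion below is illustrative, not the repo's code.
import yaml

with open("benchmarks/microbenchmarks/test/benchmark_config.yml") as f:
    cfg = yaml.safe_load(f)

shape_cfg = cfg["model_params"][0]["matrix_shapes"][0]
sizes = [2**p for p in range(shape_cfg["min_power"], shape_cfg["max_power"] + 1)]
print(sizes)  # [16384, 32768, 65536] for min_power: 14, max_power: 16
```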

benchmarks/microbenchmarks/utils.py

Lines changed: 11 additions & 3 deletions
```diff
@@ -124,7 +124,9 @@ def __init__(
     ):
         self.config = config
         self.output_dir = config.output_dir
+        self.baseline_inference_time_in_ms = 0.0
         self.model_inference_time_in_ms = 0.0
+        self.speedup = 0.0
         self.profiler_json_path: Optional[str] = None
         self.memory_profile_path: Optional[str] = None
         self.memory_visualization_path: Optional[str] = None
@@ -134,7 +136,9 @@ def to_dict(self) -> Dict[str, Any]:
         """Convert result to dictionary for main function"""
         result_dict = {
             **self.config.to_dict(),
+            "baseline_inference_time_in_ms": self.baseline_inference_time_in_ms,
             "model_inference_time_in_ms": self.model_inference_time_in_ms,
+            "speedup": self.speedup,
             "profiler_json_path": self.profiler_json_path,
             "memory_profile_path": self.memory_profile_path,
             "memory_visualization_path": self.memory_visualization_path,
@@ -299,7 +303,7 @@ def model_inference_time_in_ms(model, input_data):
         input_data: Input data for the model
 
     Returns:
-        float: Median inference time in microseconds
+        float: Median inference time in milliseconds
     """
     # First run to trigger any compilation/lazy initialization
 
@@ -315,8 +319,8 @@ def model_inference_time_in_ms(model, input_data):
     measurement = timer.timeit(number=100)
     res = measurement.mean
 
-    # Convert to microseconds
-    return res * 1e6
+    # Convert to milliseconds
+    return (res * 1e6) / 1000  # Convert microseconds to milliseconds
 
 
 def clean_caches():
@@ -386,7 +390,9 @@ def print_results(results: List[BenchmarkResult]):
             result.config.quantization or "baseline",
             result.config.sparsity or "none",
             f"{result.config.shape_name} ({result.config.m}, {result.config.k}, {result.config.n})",
+            f"{result.baseline_inference_time_in_ms:.2f}",
             f"{result.model_inference_time_in_ms:.2f}",
+            f"{result.speedup:.2f}x",
             str(result.config.enable_profiler),
         ]
@@ -398,7 +404,9 @@ def print_results(results: List[BenchmarkResult]):
             "Quantization",
             "Sparsity",
             "Shape",
+            "Baseline Inference Time (ms)",
             "Inference Time (ms)",
+            "Speedup",
             "Profiler Enabled",
         ]
```
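For context on the unit fix: `torch.utils.benchmark` reports `Measurement.mean` in seconds, so scaling by 1e3 yields milliseconds (the committed `(res * 1e6) / 1000` is equivalent). A minimal standalone sketch of that measurement pattern, assuming a toy model and input:

```python
# Sketch of the timing pattern used by model_inference_time_in_ms:
# Timer.timeit(...).mean is seconds per run, so multiply by 1e3 for milliseconds.
import torch
from torch.utils.benchmark import Timer

model = torch.nn.Linear(1024, 1024)  # toy stand-in for the benchmarked model
input_data = torch.randn(1, 1024)

with torch.no_grad():
    model(input_data)  # warm-up run to trigger any lazy initialization

timer = Timer(stmt="model(input_data)", globals={"model": model, "input_data": input_data})
mean_seconds = timer.timeit(number=100).mean  # mean seconds per call
print(f"{mean_seconds * 1e3:.2f} ms")         # same value as (mean_seconds * 1e6) / 1000
```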
