Commit 3766ed7

Add sparsity to benchmarking (#1917)

1 parent 09c2760 commit 3766ed7

7 files changed: +364 -22 lines changed

benchmarks/microbenchmarks/benchmark_inference.py (+27 -4)
@@ -24,6 +24,7 @@
     string_to_config,
 )
 from torchao.quantization import quantize_
+from torchao.sparsity.sparse_api import sparsify_
 
 
 def run(config: BenchmarkConfig) -> BenchmarkResult:
@@ -44,11 +45,33 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
 
     # Use quantize_ to apply each quantization function to the model
     m_copy = deepcopy(base_model).eval().to(config.device)
-    quantization_config = string_to_config(
-        config.quantization, high_precision_dtype=config.high_precision_dtype
+    ao_base_config = string_to_config(
+        config.quantization,
+        config.sparsity,
+        high_precision_dtype=config.high_precision_dtype,
     )
-    if quantization_config is not None:
-        quantize_(m_copy, quantization_config)
+
+    # Check if sparsity is requested and if the device is CUDA (sparsity operations require CUDA)
+    is_cuda = config.device == "cuda" and torch.cuda.is_available()
+
+    if config.sparsity is not None and (
+        config.quantization is None or "baseline" in config.quantization
+    ):
+        if is_cuda:
+            print(f"Applying {config.sparsity} sparsity to model")
+            sparsify_(m_copy, ao_base_config)
+        else:
+            print(
+                f"Warning: Skipping {config.sparsity} sparsity as it requires CUDA, but device is {config.device}"
+            )
+    elif config.sparsity is None and (
+        config.quantization is None or "baseline" in config.quantization
+    ):
+        pass  # No quantization or sparsity specified, do nothing
+    else:
+        print("Quantizing model....")
+        quantize_(m_copy, ao_base_config)
 
     if config.use_torch_compile:
         print("Compiling model....")
         m_copy = torch.compile(m_copy, mode=config.torch_compile_mode, fullgraph=True)
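
For readers who want to try the new sparsity path in isolation, here is a minimal sketch, assuming a CUDA device is available and using SemiSparseWeightConfig as a stand-in for what string_to_config would return for a "semi-sparse" recipe:

    import torch
    from torch import nn
    from torchao.sparsity.sparse_api import SemiSparseWeightConfig, sparsify_

    # Toy model; the tests below use dimensions divisible by 64 for the sparse kernels.
    model = nn.Sequential(nn.Linear(64, 64)).eval()
    if torch.cuda.is_available():
        model = model.to(device="cuda", dtype=torch.bfloat16)
        sparsify_(model, SemiSparseWeightConfig())  # mirrors the sparsity-only branch above
    else:
        print("Warning: skipping semi-sparse sparsity as it requires CUDA")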

benchmarks/microbenchmarks/benchmark_runner.py (+63 -8)
@@ -21,7 +21,7 @@
 
 import argparse
 from itertools import product
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 import yaml
 
@@ -68,6 +68,53 @@ def get_param_combinations(model_param):
     return shapes, base_params
 
 
+def get_quantization_sparsity_recipes(
+    quantization_recipes: List[str], sparsity_recipes: List[str]
+) -> Set[Tuple[str, Optional[str]]]:
+    """Generate valid quantization and sparsity recipes.
+
+    Args:
+        quantization_recipes: List of quantization recipes
+        sparsity_recipes: List of sparsity recipes
+
+    Returns:
+        Set of tuples containing (quantization_recipe, sparsity_recipe)
+        For block sparsity, quantization is always "baseline"
+        All quantization techniques are also run without sparsity
+    """
+    config_recipes = set()
+
+    # Always include baseline without sparsity
+    config_recipes.add(("baseline", None))
+
+    # Add all quantization techniques without sparsity
+    for quant_config in quantization_recipes:
+        config_recipes.add((quant_config, None))
+
+    # Process combinations of quantization and sparsity
+    for sparse_config in sparsity_recipes:
+        if sparse_config is None:
+            # Skip None sparsity as we've already added all quantization techniques without sparsity
+            continue
+        elif "block" in sparse_config:
+            # For block sparsity, only pair with baseline quantization
+            config_recipes.add(("baseline", sparse_config))
+        elif "semi" in sparse_config or "2:4" in sparse_config:
+            # For semi-sparse, only pair with compatible quantization methods
+            for quant_config in quantization_recipes:
+                if (
+                    "marlin" in quant_config
+                    or "int8dq" in quant_config
+                    or "float8dq" in quant_config
+                    or quant_config == "baseline"
+                ):
+                    config_recipes.add((quant_config, sparse_config))
+        else:
+            raise ValueError(f"Invalid sparsity recipe: {sparse_config}")
+
+    return config_recipes
+
+
 def load_benchmark_configs(cli_args: argparse.Namespace) -> List[BenchmarkConfig]:
     """Load benchmark configurations from CLI arguments and YAML file."""
     with open(cli_args.config, "r") as f:
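
To make the pairing rules concrete, here is a small worked example of what the function above returns: block sparsity pairs only with baseline, semi-sparse pairs only with marlin/int8dq/float8dq/baseline, and every quantization recipe also runs densely.

    recipes = get_quantization_sparsity_recipes(
        ["baseline", "int8wo", "marlin"],
        [None, "semi-sparse", "block"],
    )
    assert recipes == {
        ("baseline", None), ("int8wo", None), ("marlin", None),  # dense runs
        ("baseline", "semi-sparse"), ("marlin", "semi-sparse"),  # compatible pairs
        ("baseline", "block"),  # block only with baseline
    }
    # Note "int8wo" is excluded from semi-sparse: the compatibility check
    # looks for "int8dq", not "int8wo".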
@@ -78,24 +125,29 @@ def load_benchmark_configs(cli_args: argparse.Namespace) -> List[BenchmarkConfig]:
 
     # Create all possible combinations
     configs = []
+    quantization_sparsity_recipes = get_quantization_sparsity_recipes(
+        config.get("quantization_config_recipe_names", []),
+        config.get("sparsity_config_recipe_names", []),
+    )
     for model_param in config["model_params"]:
         shapes, params = get_param_combinations(model_param)
 
         # Create configs for all combinations
-        for quant_config, (shape_name, shape) in product(
-            config.get("quantization_config_recipe_names", ["baseline"]), shapes
+        for (quant_config, sparse_config), (shape_name, shape) in product(
+            quantization_sparsity_recipes,
+            shapes,
         ):
             configs.append(
                 BenchmarkConfig(
                     quantization=quant_config,
+                    sparsity=sparse_config,
                     params=params,
                     shape_name=shape_name,
                     shape=shape,
                     output_dir=output_dir,
                     benchmark_mode=benchmark_mode,
                 )
             )
-
     return configs
 
 
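Note that the recipe set is crossed with every shape per model entry, so the config count is len(recipes) * len(shapes); a tiny illustration of the expansion:

    from itertools import product

    recipes = {("baseline", None), ("marlin", "semi-sparse")}
    shapes = [("custom", [1024, 1024, 1024]), ("custom", [2048, 2048, 2048])]
    combos = list(product(recipes, shapes))
    assert len(combos) == len(recipes) * len(shapes)  # 2 recipes x 2 shapes = 4 configs
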
@@ -104,14 +156,17 @@ def run_inference_benchmarks_from_config(configs: List[BenchmarkConfig]) -> None:
     from benchmarks.microbenchmarks.benchmark_inference import run as run_inference
 
     results = []
-    print("Benchmarking Inference ......")
+    print("----------------- RUNNING BENCHMARKS FOR INFERENCE -----------------------")
     for config in configs:
+        print("----------------------------------------")
         try:
-            print(f"Running: {config.name}")
+            print(
+                f"Running: {config.name} for Quantization: {config.quantization} and Sparsity: {config.sparsity}"
+            )
             result = run_inference(config)  # Pass the config object directly
             results.append(result)
-        except Exception as e:
-            print(f"Error running benchmark {config.name}: {e}")
+        except Exception:
+            print(f"Error running benchmark {config.name}")
             continue
 
     # Add results to csv

benchmarks/microbenchmarks/test/benchmark_config.yml (+6 -2)

@@ -1,9 +1,13 @@
 # Sample configuration for inference benchmarks
 benchmark_mode: "inference"
 quantization_config_recipe_names:
-  - "baseline"
+  # Will run a baseline inference for model by default, without quantization for comparison
   - "int4wo-32"
-  - "int4wo-128"
+  - "marlin"
+sparsity_config_recipe_names:
+  # Will run a baseline inference for model by default, without sparsity for comparison
+  - "semi-sparse"
+  - "block"
 output_dir: "benchmarks/microbenchmarks/results"
 model_params:
   - name: "small_bf16_linear"
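
One way to exercise this sample config end to end is through the runner's own loader; a sketch that reuses the argparse.Namespace pattern from the tests below:

    import argparse

    from benchmarks.microbenchmarks.benchmark_runner import (
        load_benchmark_configs,
        run_inference_benchmarks_from_config,
    )

    configs = load_benchmark_configs(
        argparse.Namespace(config="benchmarks/microbenchmarks/test/benchmark_config.yml")
    )
    run_inference_benchmarks_from_config(configs)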

benchmarks/microbenchmarks/test/test_benchmark_inference.py (+66 -1)
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import tempfile
 import unittest
+from unittest.mock import patch
 
 from benchmarks.microbenchmarks.benchmark_inference import run
 from benchmarks.microbenchmarks.utils import BenchmarkConfig, BenchmarkResult
@@ -17,6 +18,7 @@ def setUp(self):
 
         self.config = BenchmarkConfig(
             quantization="baseline",
+            sparsity="semi-sparse",
             params={
                 "high_precision_dtype": "torch.float32",
                 "use_torch_compile": False,
@@ -35,11 +37,74 @@ def tearDown(self):
 
         shutil.rmtree(self.temp_dir)
 
-    def test_run_inference(self):
+    @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
+    def test_run_inference(self, mock_string_to_config):
+        # Mock string_to_config to return a valid config
+        from torchao.sparsity.sparse_api import SemiSparseWeightConfig
+
+        mock_string_to_config.return_value = SemiSparseWeightConfig()
+
         result = run(self.config)
         self.assertIsInstance(result, BenchmarkResult)
         self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
 
+    @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
+    def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config):
+        """Test running inference with sparsity configurations"""
+        # Mock string_to_config to return valid configs
+        from torchao.dtypes import MarlinSparseLayout
+        from torchao.quantization import Int4WeightOnlyConfig
+
+        # Test with semi-sparse config
+        mock_string_to_config.return_value = Int4WeightOnlyConfig(
+            layout=MarlinSparseLayout()
+        )
+        config = BenchmarkConfig(
+            quantization="marlin",
+            sparsity="semi-sparse",
+            params={
+                "high_precision_dtype": "torch.float32",
+                "use_torch_compile": False,
+                "device": "cpu",
+                "model_type": "linear",
+            },
+            shape_name="custom",
+            shape=[64, 64, 64],  # Use dimensions divisible by 64
+            output_dir=self.temp_dir,
+            benchmark_mode="inference",
+        )
+        result = run(config)
+        self.assertIsInstance(result, BenchmarkResult)
+        self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
+
+    @patch("benchmarks.microbenchmarks.benchmark_inference.string_to_config")
+    def test_run_inference_with_block_sparsity(self, mock_string_to_config):
+        """Test running inference with sparsity configurations"""
+        # Mock string_to_config to return valid configs
+        from torchao.sparsity.sparse_api import (
+            BlockSparseWeightConfig,
+        )
+
+        # Test with block sparsity
+        mock_string_to_config.return_value = BlockSparseWeightConfig()
+        config = BenchmarkConfig(
+            quantization="baseline",
+            sparsity="block",
+            params={
+                "high_precision_dtype": "torch.float32",
+                "use_torch_compile": False,
+                "device": "cpu",
+                "model_type": "linear",
+            },
+            shape_name="custom",
+            shape=[64, 64, 64],  # Use dimensions divisible by 64
+            output_dir=self.temp_dir,
+            benchmark_mode="inference",
+        )
+        result = run(config)
+        self.assertIsInstance(result, BenchmarkResult)
+        self.assertTrue(hasattr(result, "model_inference_time_in_ms"))
+
 
 if __name__ == "__main__":
     unittest.main()
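
One detail worth noting in these tests: string_to_config is patched where it is looked up, in benchmarks.microbenchmarks.benchmark_inference, not where it is defined. That is the standard unittest.mock pattern; a minimal sketch:

    from unittest.mock import patch

    from torchao.sparsity.sparse_api import SemiSparseWeightConfig

    # Patching the name in the consuming module ensures run() resolves to the
    # mock when it calls string_to_config(...).
    with patch(
        "benchmarks.microbenchmarks.benchmark_inference.string_to_config"
    ) as mock_cfg:
        mock_cfg.return_value = SemiSparseWeightConfig()
        # ... call run(config) here, as in the tests above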

benchmarks/microbenchmarks/test/test_benchmark_runner.py (+104 -0)
@@ -13,6 +13,7 @@
 
 from benchmarks.microbenchmarks.benchmark_runner import (
     get_param_combinations,
+    get_quantization_sparsity_recipes,
     get_shapes_for_config,
     load_benchmark_configs,
     run_inference_benchmarks_from_config,
@@ -88,6 +89,109 @@ def test_run_inference_benchmarks_from_config(self):
         results_file = Path(self.temp_dir) / "results.csv"
         self.assertTrue(results_file.exists())
 
+    def test_get_quantization_sparsity_recipes(self):
+        """Test generation of valid quantization and sparsity recipe combinations"""
+        # Test basic combinations
+        quant_recipes = ["baseline", "int8wo"]
+        sparse_recipes = [None, "semi-sparse"]
+        recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)
+        self.assertIn(("baseline", None), recipes)
+        self.assertIn(("int8wo", None), recipes)
+        self.assertIn(("baseline", "semi-sparse"), recipes)
+
+        # Test marlin with semi-sparse
+        quant_recipes = ["marlin", "baseline"]
+        sparse_recipes = [None, "semi-sparse"]
+        recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)
+        self.assertIn(("marlin", "semi-sparse"), recipes)
+        self.assertIn(("baseline", None), recipes)
+
+        # Test block sparsity
+        quant_recipes = ["baseline"]
+        sparse_recipes = [None, "block"]
+        recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)
+        self.assertIn(("baseline", "block"), recipes)
+
+    def test_none_string_raises_error(self):
+        """Test that passing 'None' as a string raises an error"""
+        quant_recipes = ["baseline"]
+        sparse_recipes = ["None"]  # "None" as a string should raise an error
+        with self.assertRaises(ValueError):
+            get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)
+
+    def test_block_sparsity_with_quantization(self):
+        """Test that block sparsity is only paired with baseline quantization"""
+        quant_recipes = ["baseline", "int8wo", "int4wo", "marlin"]
+        sparse_recipes = ["block"]
+        recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)
+
+        # Block sparsity should only be paired with baseline
+        self.assertIn(("baseline", "block"), recipes)
+        self.assertNotIn(("int8wo", "block"), recipes)
+        self.assertNotIn(("int4wo", "block"), recipes)
+        self.assertNotIn(("marlin", "block"), recipes)
+
+        # All quantization techniques should be run without sparsity
+        self.assertIn(("baseline", None), recipes)
+        self.assertIn(("int8wo", None), recipes)
+        self.assertIn(("int4wo", None), recipes)
+        self.assertIn(("marlin", None), recipes)
+
+    def test_all_quantization_without_sparsity(self):
+        """Test that all quantization techniques are run without sparsity"""
+        quant_recipes = ["baseline", "int8wo", "int4wo", "marlin"]
+        sparse_recipes = [None, "semi-sparse", "block"]
+        recipes = get_quantization_sparsity_recipes(quant_recipes, sparse_recipes)
+
+        # All quantization techniques should be run without sparsity
+        for quant in quant_recipes:
+            self.assertIn((quant, None), recipes)
+
+    @patch(
+        "benchmarks.microbenchmarks.benchmark_runner.get_quantization_sparsity_recipes"
+    )
+    def test_load_benchmark_configs_with_sparsity(self, mock_get_recipes):
+        """Test loading benchmark configs with sparsity options"""
+        # Mock get_quantization_sparsity_recipes to return a valid set of recipes
+        mock_get_recipes.return_value = {("baseline", None), ("marlin", "semi-sparse")}
+
+        test_config = {
+            "benchmark_mode": "inference",
+            "quantization_config_recipe_names": ["baseline", "marlin"],
+            "sparsity_config_recipe_names": [
+                None,
+                "semi-sparse",
+            ],  # Use None instead of "None"
+            "output_dir": self.temp_dir,
+            "model_params": [
+                {
+                    "matrix_shapes": [
+                        {"name": "custom", "shapes": [[1024, 1024, 1024]]}
+                    ],
+                    "high_precision_dtype": "torch.bfloat16",
+                    "device": "cpu",
+                    "model_type": "linear",
+                }
+            ],
+        }
+
+        config_path = Path(self.temp_dir) / "test_sparsity_config.yml"
+        with open(config_path, "w") as f:
+            yaml.dump(test_config, f)
+
+        configs = load_benchmark_configs(argparse.Namespace(config=str(config_path)))
+
+        # Check that we get configs for baseline and marlin with appropriate sparsity
+        self.assertTrue(
+            any(c.quantization == "baseline" and c.sparsity is None for c in configs)
+        )
+        self.assertTrue(
+            any(
+                c.quantization == "marlin" and c.sparsity == "semi-sparse"
+                for c in configs
+            )
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
