Commit 2614d95

[not4land] Local torchao benchmark

Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:

1 parent 675fb8f commit 2614d95

8 files changed: +240 -16 lines changed
manual_cron.sh (+13)

@@ -0,0 +1,13 @@
target_hour=18
target_min=00
while true
do
    current_hour=$(date +%H)
    current_min=$(date +%M)
    if [ $current_hour -eq $target_hour ] && [ $current_min -eq $target_min ] ; then
        echo "Cron job started at $(date)"
        sh ~/local/cron_jobs/benchmark/cron_script.sh > ~/local/cron_jobs/benchmark/local_cron_log 2>~/local/cron_jobs/benchmark/local_cron_err
        echo "Cron job executed at $(date)"
    fi
    sleep 60
done

requirements.txt (+1 -1)

@@ -9,7 +9,7 @@ pytest
 pytest-benchmark
 requests
 tabulate
-git+https://github.com/huggingface/pytorch-image-models.git@730b907
+# git+https://github.com/huggingface/pytorch-image-models.git@730b907
 # this version of transformers is required by linger-kernel
 # https://github.com/linkedin/Liger-Kernel/blob/main/pyproject.toml#L23
 transformers==4.44.2

upload_to_s3.py (+57)

@@ -0,0 +1,57 @@
import os
import io
import json
from functools import lru_cache
import boto3
from typing import Any
import gzip


@lru_cache
def get_s3_resource() -> Any:
    return boto3.resource("s3")


def upload_to_s3(
    bucket_name: str,
    key: str,
    json_path: str,
) -> None:
    print(f"Writing {json_path} documents to S3")
    data = []
    with open(f"{os.path.splitext(json_path)[0]}.json", "r") as f:
        for line in f:
            data.append(json.loads(line))

    body = io.StringIO()
    for benchmark_entry in data:
        json.dump(benchmark_entry, body)
        body.write("\n")

    try:
        get_s3_resource().Object(
            f"{bucket_name}",
            f"{key}",
        ).put(
            Body=body.getvalue(),
            ContentType="application/json",
        )
    except Exception as e:  # fixed: bare `except e:` is invalid Python
        print("failed to upload to s3:", e)
        return
    print("Done!")

if __name__ == "__main__":
    import argparse
    import datetime
    parser = argparse.ArgumentParser(description="Upload benchmark result json file to ClickHouse")
    parser.add_argument("--json-path", type=str, help="json file path to upload to ClickHouse", required=True)
    args = parser.parse_args()
    today = datetime.date.today()
    today = datetime.datetime.combine(today, datetime.time.min)
    today_timestamp = str(int(today.timestamp()))
    print("Today timestamp:", today_timestamp)
    import subprocess
    # Execute `hostname -s` and capture the output
    output = subprocess.check_output(['hostname', '-s'])
    # Decode the output from bytes to string
    hostname = output.decode('utf-8').strip()
    upload_to_s3("ossci-benchmarks", f"v3/pytorch/ao/{hostname}/torchbenchmark-torchbench-" + today_timestamp + ".json", args.json_path)

userbenchmark/dynamo/dynamobench/common.py (+14 -3)

@@ -62,6 +62,7 @@
     same,
 )
 from torch._logging.scribe import open_source_signpost
+from .utils import benchmark_and_write_json_result
 
 
 try:
@@ -555,8 +556,17 @@ def output_signpost(data, args, suite, error=None):
     )
 
 
-def nothing(f):
-    return f
+def nothing(model_iter_fn):
+    def _apply(module: torch.nn.Module, example_inputs: Any):
+        if isinstance(example_inputs, dict):
+            args = ()
+            kwargs = example_inputs
+        else:
+            args = example_inputs
+            kwargs = {}
+        benchmark_and_write_json_result(module, args, kwargs, "noquant", "cuda", compile=False)
+        model_iter_fn(module, example_inputs)
+    return _apply
 
 
 @functools.lru_cache(None)
@@ -4147,8 +4157,9 @@ def get_example_inputs(self):
             "int8dynamic",
             "int8weightonly",
             "int4weightonly",
-            "autoquant",
             "noquant",
+            "autoquant",
+            "autoquant-all",
         ],
         default=None,
         help="Measure speedup of torchao quantization with TorchInductor baseline",

userbenchmark/dynamo/dynamobench/torchao_backend.py (+47 -1)

@@ -1,7 +1,9 @@
 from typing import Any, Callable
 
 import torch
+from .utils import benchmark_and_write_json_result
 
+_OUTPUT_JSON_PATH = "benchmark_results"
 
 def setup_baseline():
     from torchao.quantization.utils import recommended_inductor_config_setter
@@ -20,10 +22,21 @@ def torchao_optimize_ctx(quantization: str):
         quantize_,
     )
     from torchao.utils import unwrap_tensor_subclass
+    import torchao
 
     def inner(model_iter_fn: Callable):
         def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
             if getattr(module, "_quantized", None) is None:
+                if quantization == "noquant":
+                    if isinstance(example_inputs, dict):
+                        args = ()
+                        kwargs = example_inputs
+                    else:
+                        args = example_inputs
+                        kwargs = {}
+
+                    benchmark_and_write_json_result(module, args, kwargs, "noquant", "cuda")
+
                 if quantization == "int8dynamic":
                     quantize_(
                         module,
@@ -34,7 +47,30 @@ def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
                     quantize_(module, int8_weight_only(), set_inductor_config=False)
                 elif quantization == "int4weightonly":
                     quantize_(module, int4_weight_only(), set_inductor_config=False)
-                if quantization == "autoquant":
+                if quantization == "autoquant-all":
+                    autoquant(module, error_on_unseen=False, set_inductor_config=False, qtensor_class_list=torchao.quantization.ALL_AUTOQUANT_CLASS_LIST)
+                    if isinstance(example_inputs, dict):
+                        module(**example_inputs)
+                    else:
+                        module(*example_inputs)
+                    from torchao.quantization.autoquant import AUTOQUANT_CACHE
+
+                    if len(AUTOQUANT_CACHE) == 0:
+                        raise Exception(  # noqa: TRY002
+                            "NotAutoquantizable"
+                            f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
+                        )
+
+                    if isinstance(example_inputs, dict):
+                        args = ()
+                        kwargs = example_inputs
+                    else:
+                        args = example_inputs
+                        kwargs = {}
+
+                    torchao.quantization.utils.recommended_inductor_config_setter()
+                    benchmark_and_write_json_result(module, args, kwargs, quantization, "cuda")
+                elif quantization == "autoquant":
                     autoquant(module, error_on_unseen=False, set_inductor_config=False)
                     if isinstance(example_inputs, dict):
                         module(**example_inputs)
@@ -47,6 +83,16 @@ def _torchao_apply(module: torch.nn.Module, example_inputs: Any):
                             "NotAutoquantizable"
                             f"Found no autoquantizable layers in model {type(module)}, stopping autoquantized run"
                         )
+
+                    if isinstance(example_inputs, dict):
+                        args = ()
+                        kwargs = example_inputs
+                    else:
+                        args = example_inputs
+                        kwargs = {}
+
+                    torchao.quantization.utils.recommended_inductor_config_setter()
+                    benchmark_and_write_json_result(module, args, kwargs, quantization, "cuda")
                 else:
                     unwrap_tensor_subclass(module)
                 setattr(module, "_quantized", True)  # noqa: B010
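The dict-vs-tuple normalization of example_inputs now appears four times across common.py and torchao_backend.py. A small helper could factor it out; a minimal sketch (hypothetical, not part of the commit):

# hypothetical helper, not part of the commit
from typing import Any, Dict, Tuple

def split_example_inputs(example_inputs: Any) -> Tuple[tuple, Dict[str, Any]]:
    # TorchBench passes example inputs either as a kwargs dict or as a
    # positional sequence; benchmark_and_write_json_result wants both forms.
    if isinstance(example_inputs, dict):
        return (), example_inputs
    return tuple(example_inputs), {}

# usage: args, kwargs = split_example_inputs(example_inputs)
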
userbenchmark/dynamo/dynamobench/utils.py (+80)

@@ -0,0 +1,80 @@
import json
import torch
import platform
import os
import time
import datetime
import hashlib

# Output path stem (mirrors the constant in torchao_backend.py); needed here
# because write_json_result is called with it at the bottom of this file.
_OUTPUT_JSON_PATH = "benchmark_results"


def get_arch_name() -> str:
    if torch.cuda.is_available():
        return torch.cuda.get_device_name()
    else:
        # This returns x86_64 or arm64 (for aarch64)
        return platform.machine()


def write_json_result(output_json_path, headers, row, compile):
    """
    Write the result into JSON format, so that it can be uploaded to the benchmark database
    to be displayed on the OSS dashboard. The JSON format is defined at
    https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
    """
    mapping_headers = {headers[i]: v for i, v in enumerate(row)}
    today = datetime.date.today()
    sha_hash = hashlib.sha256(str(today).encode("utf-8")).hexdigest()
    first_second = datetime.datetime.combine(today, datetime.time.min)
    workflow_id = int(first_second.timestamp())
    job_id = workflow_id + 1
    record = {
        "timestamp": int(time.time()),
        "schema_version": "v3",
        "name": "devvm local benchmark",
        "repo": "pytorch/ao",
        "head_branch": "main",
        "head_sha": sha_hash,
        "workflow_id": workflow_id,
        "run_attempt": 1,
        "job_id": job_id,
        "benchmark": {
            "name": "TorchAO benchmark",
            "mode": "inference",
            "dtype": mapping_headers["dtype"],
            "extra_info": {
                "device": mapping_headers["device"],
                "arch": mapping_headers["arch"],
                "min_sqnr": None,
                "compile": compile,
            },
        },
        "model": {
            "name": mapping_headers["name"],
            "type": "model",
            # TODO: make this configurable
            "origins": ["torchbench"],
        },
        "metric": {
            "name": mapping_headers["metric"],
            "benchmark_values": [mapping_headers["actual"]],
            "target_value": mapping_headers["target"],
        },
    }

    with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f:
        print(json.dumps(record), file=f)

def benchmark_and_write_json_result(model, args, kwargs, quantization, device, compile=True):
    print(quantization + " run")
    from torchao.utils import benchmark_model, profiler_runner
    if compile:
        model = torch.compile(model, mode="max-autotune")
    benchmark_model(model, 20, args, kwargs)  # warmup
    elapsed_time = benchmark_model(model, 100, args, kwargs)
    print("elapsed_time: ", elapsed_time, " milliseconds")

    # torch.compile wraps the model in an OptimizedModule; fall back to the
    # model itself when running eager (compile=False, as in common.py)
    name = getattr(model, "_orig_mod", model).__class__.__name__
    headers = ["name", "dtype", "compile", "device", "arch", "metric", "actual", "target"]
    arch = get_arch_name()
    dtype = quantization
    performance_result = [name, dtype, compile, device, arch, "time_ms(avg)", elapsed_time, None]
    # pass `compile` through; the committed call omitted this required argument
    write_json_result(_OUTPUT_JSON_PATH, headers, performance_result, compile)
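A minimal sketch of calling the helper directly (the model and inputs are illustrative; assumes a CUDA device and torchao installed):

# illustrative only, not part of the commit
import torch

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).cuda().to(torch.bfloat16)
example_args = (torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16),)
# Compiles with max-autotune, warms up for 20 runs, times 100 runs, and
# appends one JSON record to benchmark_results.json:
benchmark_and_write_json_result(model, example_args, {}, "noquant", "cuda")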

userbenchmark/group_bench/configs/torch_ao.yaml (+2 -4)

@@ -10,7 +10,5 @@ metrics:
 test_group:
   test_batch_size_default:
     subgroup:
-      - extra_args:
-      - extra_args: --quantization int8dynamic
-      - extra_args: --quantization int8weightonly
-      - extra_args: --quantization int4weightonly
+      - extra_args: --quantization noquant
+      - extra_args: --quantization autoquant

userbenchmark/torchao/run.py (+26 -7)

@@ -12,18 +12,18 @@
 
 
 def _get_ci_args(
-    backend: str, modelset: str, dtype, mode: str, device: str, experiment: str
+    quantization: str, modelset: str, dtype, mode: str, device: str, experiment: str
 ) -> List[List[str]]:
     if modelset == "timm":
         modelset_full_name = "timm_models"
     else:
         modelset_full_name = modelset
-    output_file_name = f"torchao_{backend}_{modelset_full_name}_{dtype}_{mode}_{device}_{experiment}.csv"
+    output_file_name = f"torchao_{quantization}_{modelset_full_name}_{dtype}_{mode}_{device}_{experiment}.csv"
     ci_args = [
         "--progress",
         f"--{modelset}",
         "--quantization",
-        f"{backend}",
+        f"{quantization}",
         f"--{mode}",
         f"--{dtype}",
         f"--{experiment}",
@@ -32,16 +32,35 @@ def _get_ci_args(
     ]
     return ci_args
 
+def _get_eager_baseline_args(quantization: str, modelset: str, dtype, mode: str, device: str, experiment: str):
+    # fixed: the parameter was committed as `model_set` while the body uses `modelset`
+    if modelset == "timm":
+        modelset_full_name = "timm_models"
+    else:
+        modelset_full_name = modelset
+    output_file_name = f"torchao_{quantization}_{modelset_full_name}_{dtype}_{mode}_{device}_{experiment}_eager.csv"
+    ci_args = [
+        "--progress",
+        f"--{modelset}",
+        "--quantization",
+        f"{quantization}",
+        f"--{mode}",
+        f"--{dtype}",
+        f"--{experiment}",
+        "--nothing",
+        "--output",
+        f"{str(OUTPUT_DIR.joinpath(output_file_name).resolve())}",
+    ]
+    return ci_args
 
 def _get_full_ci_args(modelset: str) -> List[List[str]]:
-    backends = ["autoquant", "int8dynamic", "int8weightonly", "noquant"]
+    quantizations = ["autoquant-all", "autoquant", "noquant"]
     modelset = [modelset]
     dtype = ["bfloat16"]
     mode = ["inference"]
     device = ["cuda"]
-    experiment = ["performance", "accuracy"]
-    cfgs = itertools.product(*[backends, modelset, dtype, mode, device, experiment])
-    return [_get_ci_args(*cfg) for cfg in cfgs]
+    experiment = ["performance"]
+    cfgs = itertools.product(*[quantizations, modelset, dtype, mode, device, experiment])
+    # the locals above are single-element lists, so index into them for the
+    # scalar-typed eager-baseline helper (fixed from the committed call)
+    return [_get_ci_args(*cfg) for cfg in cfgs] + [_get_eager_baseline_args("noquant", modelset[0], dtype[0], mode[0], device[0], experiment[0])]
 
 
 def _get_output(pt2_args):
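For illustration, one configuration from _get_full_ci_args("timm") would expand to roughly the following argument list (the resolved output path is elided):

# illustrative only; the exact flags come from _get_ci_args above
ci_args = [
    "--progress", "--timm",
    "--quantization", "autoquant",
    "--inference", "--bfloat16", "--performance",
    "--output", "torchao_autoquant_timm_models_bfloat16_inference_cuda_performance.csv",
]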
