Commit 7003718

hipblaslt gemm tuning example (#16)
1 parent afe51b7 commit 7003718

2 files changed: 139 additions & 0 deletions


examples/offline_tune/README.md

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
# Offline Tune

## 1. GEMM Tune

Use the `hipblaslt-bench` tool to perform GEMM tuning.

`hipblaslt-bench` is usually located under `/opt/rocm/bin`. If it is not available in your environment or Docker image, you will need to reinstall hipBLASLt.
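You can quickly check whether it is available, for example:

```
ls /opt/rocm/bin/hipblaslt-bench
```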

### Install hipBLASLt (Optional)

For build and install instructions, see: https://github.com/ROCm/hipBLASLt?tab=readme-ov-file#build-and-install

If you only target MI300X (gfx942), you can use the following command for a quicker build, reducing the compilation time to under 2 hours.

```
./install.sh -idc --logic-yaml-filter gfx942/*/* -a gfx942 -j 256 --build_dir build
```

### Step 1: Dump Shapes
* Set the hipBLASLt logging environment variables.
* Run your training code.
* Unset the environment variables.
* The GEMM shapes will be dumped into `dump_gemm_shapes.txt`.
* Note: If you only need to dump shapes, there is usually no need to train for many iterations; a few steps are enough, since every step uses the same shapes.

```
export HIPBLASLT_LOG_MASK=32
export HIPBLASLT_LOG_FILE=dump_gemm_shapes.txt

./run_your_code

unset HIPBLASLT_LOG_MASK
unset HIPBLASLT_LOG_FILE
```
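
Each dumped line is a complete `hipblaslt-bench` command describing one GEMM. A hypothetical example (shape, data types, and solution index purely illustrative) might look like:

```
hipblaslt-bench -m 4096 -n 8192 -k 4096 --transA T --transB N --a_type f16_r --b_type f16_r --c_type f16_r --d_type f16_r --compute_type f32_r --algo_method index --solution_index 894
```

The tuning script in Step 2 strips `--algo_method`/`--solution_index` from each line and re-benchmarks the shape across the available solutions.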

### Step 2: Tuning
Run `offline_tune_gemm.py` and save the tuned results in `tune_gemm_results.txt`.

```
python3 offline_tune_gemm.py \
    --dump-shape-path /PATH/TO/dump_gemm_shapes.txt \
    --tune-result-path /PATH/TO/tune_gemm_results.txt
```
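
Equivalently, the `OfflineTuneGemm` class defined in `offline_tune_gemm.py` can be used from Python directly, for example:

```
from offline_tune_gemm import OfflineTuneGemm

tuner = OfflineTuneGemm("/PATH/TO/dump_gemm_shapes.txt")
tuner.tune("/PATH/TO/tune_gemm_results.txt", device_id="0")
```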

### Step 3: Train with the Tuned Results
* Set the tuning-results environment variable.
* Start your tasks.

```
export HIPBLASLT_TUNING_OVERRIDE_FILE=tune_gemm_results.txt
./run_your_code
```

# Reference

https://rocm.blogs.amd.com/artificial-intelligence/gemm_blog/README.html
examples/offline_tune/offline_tune_gemm.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
import argparse
import copy
import os
import shlex
import subprocess


def is_hip():
    # Returns True when PyTorch was built with ROCm/HIP support.
    import torch

    if torch.version.hip is not None:
        return True
    return False


class OfflineTuneGemm:

    def __init__(self, dump_gemm_shape_file_path):
        self.HIPBLASLT_BENCH = "/opt/rocm/bin/hipblaslt-bench "
        self.ROTATING_BUFFER = 512
        self.RUN_NUMS = 20
        self.REQUESTED_SOLUTION = -1
        self.SKIP_SLOW_SOLUTION_RATIO = 0.7

        self.src_script_dict_list = []
        self.src_script_list = []
        self.tune_script_dict_list = []
        self.tune_script_list = []
        self.process_raw_dump(dump_gemm_shape_file_path)

    def process_raw_dump(self, dump_gemm_shape_file_path):
        with open(dump_gemm_shape_file_path, "r", encoding="utf-8") as file:
            lines = file.readlines()
        # Deduplicate identical GEMM shapes and keep a stable order.
        lines = list(set(lines))
        lines.sort()

        for line in lines:
            line = line.strip().split(" ")
            line = [item for item in line if item.strip()]
            if line and line[0] == "hipblaslt-bench":
                # Parse "--flag value" pairs into a dict.
                src_script_dict = {}
                for item in line[1:]:
                    if item.startswith("--") or item.startswith("-"):
                        key = item
                    else:
                        src_script_dict[key] = item
                # src script: replay the original GEMM with benchmark settings.
                src_script_dict["--rotating"] = self.ROTATING_BUFFER
                src_script_dict["--cold_iters"] = self.RUN_NUMS
                src_script_dict["--iters"] = self.RUN_NUMS
                src_script = self.HIPBLASLT_BENCH + " ".join(f"{k} {v}" for k, v in src_script_dict.items())
                self.src_script_dict_list.append(src_script_dict)
                self.src_script_list.append(src_script)
                # tune script: drop the fixed solution and search over all solutions.
                tune_script_dict = copy.deepcopy(src_script_dict)
                del tune_script_dict["--algo_method"]
                del tune_script_dict["--solution_index"]
                tune_script_dict["--requested_solution"] = self.REQUESTED_SOLUTION
                tune_script_dict["--skip_slow_solution_ratio"] = self.SKIP_SLOW_SOLUTION_RATIO
                tune_script = self.HIPBLASLT_BENCH + " ".join(f"{k} {v}" for k, v in tune_script_dict.items())
                self.tune_script_dict_list.append(tune_script_dict)
                self.tune_script_list.append(tune_script)

    # TODO: use more devices to tune in parallel
    def tune(self, tune_gemm_results_file_path, device_id="0"):
        env = os.environ.copy()
        # Pin the benchmark to a single GPU.
        if is_hip():
            env.update({"HIP_VISIBLE_DEVICES": device_id})
        else:
            env.update({"CUDA_VISIBLE_DEVICES": device_id})
        # Tuned results are written into this file.
        env.update({"HIPBLASLT_TUNING_FILE": tune_gemm_results_file_path})

        for idx, script in enumerate(self.tune_script_list):
            print(f"Tune[{idx}/{len(self.tune_script_list)}]: {script}")
            subprocess.run(shlex.split(script), env=env)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--dump-shape-path", type=str)
    parser.add_argument("--tune-result-path", type=str)
    # parser.add_argument("--device-id", type=str, default="0")
    args = parser.parse_args()

    tuner = OfflineTuneGemm(args.dump_shape_path)
    tuner.tune(args.tune_result_path)
