Skip to content

Commit 2ada0a8

Browse files
hipblaslt tune tool support multi-devices to tune (#17)
1 parent cc6985e commit 2ada0a8

File tree

2 files changed

+45
-15
lines changed

2 files changed

+45
-15
lines changed

examples/offline_tune/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ Run `offline_tune_gemm.py` and save tuned results in `tune_gemm_results.txt`
3737
```
3838
python3 offline_tune_gemm.py \
3939
--dump-shape-path /PATH/TO/dump_gemm_shapes.txt \
40-
--tune-result-path /PATH/TO/tune_gemm_results.txt
40+
--tune-result-path /PATH/TO/tune_gemm_results.txt \
41+
--num-devices 8
4142
```
4243

4344
### Step 3: Use tuned results to Train

examples/offline_tune/offline_tune_gemm.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import os
44
import shlex
55
import subprocess
6+
import time
7+
from multiprocessing import Process, Queue
68

79

810
def is_hip():
@@ -13,6 +15,22 @@ def is_hip():
1315
return False
1416

1517

18+
def worker(device_id, tune_gemm_results_file_path, task_queue):
19+
env = os.environ.copy()
20+
if is_hip():
21+
env["HIP_VISIBLE_DEVICES"] = device_id
22+
else:
23+
env["CUDA_VISIBLE_DEVICES"] = device_id
24+
env["HIPBLASLT_TUNING_FILE"] = tune_gemm_results_file_path
25+
26+
while True:
27+
script = task_queue.get()
28+
if script is None:
29+
break
30+
print(f"Device {device_id} processing: {script}")
31+
subprocess.run(shlex.split(script), check=True, env=env)
32+
33+
1634
class OfflineTuneGemm:
1735

1836
def __init__(self, dump_gemm_shape_file_path):
@@ -61,26 +79,37 @@ def process_raw_dump(self, dump_gemm_shape_file_path):
6179
self.tune_script_dict_list.append(tune_script_dict)
6280
self.tune_script_list.append(tune_script)
6381

64-
# TODO: use more device to tune
65-
def tune(self, tune_gemm_results_file_path, device_id="0"):
66-
env = os.environ.copy()
67-
if is_hip():
68-
env.update({"HIP_VISIBLE_DEVICES": device_id})
69-
else:
70-
env.update({"CUDA_VISIBLE_DEVICES": device_id})
71-
env.update({"HIPBLASLT_TUNING_FILE": tune_gemm_results_file_path})
82+
def tune(self, tune_gemm_results_file_path, device_ids=["0"]):
83+
task_queue = Queue()
84+
for script in self.tune_script_list:
85+
task_queue.put(script)
86+
for _ in device_ids:
87+
task_queue.put(None)
88+
89+
start_time = time.time()
90+
processes = []
91+
for device_id in device_ids:
92+
p = Process(target=worker, args=(device_id, tune_gemm_results_file_path, task_queue))
93+
p.start()
94+
processes.append(p)
95+
96+
for p in processes:
97+
p.join()
7298

73-
for idx, script in enumerate(self.tune_script_list):
74-
print(f"Tune[{idx}/{len(self.tune_script_list)}]:{script}")
75-
subprocess.run(shlex.split(script), env=env)
99+
end_time = time.time()
100+
elapsed_time = end_time - start_time
101+
print(
102+
f"Tune cases Nums: {len(self.tune_script_list)}. Elapsed Time: {elapsed_time:.2f} s",
103+
)
76104

77105

78106
if __name__ == "__main__":
79107
parser = argparse.ArgumentParser()
80-
parser.add_argument("--dump-shape-pathh", type=str)
108+
parser.add_argument("--dump-shape-path", type=str)
81109
parser.add_argument("--tune-result-path", type=str)
82-
# parser.add_argument("--device-id", type=str, default="0")
110+
parser.add_argument("--num-devices", type=int, default=1)
83111
args = parser.parse_args()
112+
device_ids = [str(i) for i in range(args.num_devices)]
84113

85114
tuner = OfflineTuneGemm(args.dump_shape_path)
86-
tuner.tune(args.tune_result_path)
115+
tuner.tune(args.tune_result_path, device_ids)

0 commit comments

Comments
 (0)