33import os
44import shlex
55import subprocess
6+ import time
7+ from multiprocessing import Process , Queue
68
79
810def is_hip ():
@@ -13,6 +15,22 @@ def is_hip():
1315 return False
1416
1517
18+ def worker (device_id , tune_gemm_results_file_path , task_queue ):
19+ env = os .environ .copy ()
20+ if is_hip ():
21+ env ["HIP_VISIBLE_DEVICES" ] = device_id
22+ else :
23+ env ["CUDA_VISIBLE_DEVICES" ] = device_id
24+ env ["HIPBLASLT_TUNING_FILE" ] = tune_gemm_results_file_path
25+
26+ while True :
27+ script = task_queue .get ()
28+ if script is None :
29+ break
30+ print (f"Device { device_id } processing: { script } " )
31+ subprocess .run (shlex .split (script ), check = True , env = env )
32+
33+
1634class OfflineTuneGemm :
1735
1836 def __init__ (self , dump_gemm_shape_file_path ):
@@ -61,26 +79,37 @@ def process_raw_dump(self, dump_gemm_shape_file_path):
6179 self .tune_script_dict_list .append (tune_script_dict )
6280 self .tune_script_list .append (tune_script )
6381
64- # TODO: use more device to tune
65- def tune (self , tune_gemm_results_file_path , device_id = "0" ):
66- env = os .environ .copy ()
67- if is_hip ():
68- env .update ({"HIP_VISIBLE_DEVICES" : device_id })
69- else :
70- env .update ({"CUDA_VISIBLE_DEVICES" : device_id })
71- env .update ({"HIPBLASLT_TUNING_FILE" : tune_gemm_results_file_path })
82+ def tune (self , tune_gemm_results_file_path , device_ids = ["0" ]):
83+ task_queue = Queue ()
84+ for script in self .tune_script_list :
85+ task_queue .put (script )
86+ for _ in device_ids :
87+ task_queue .put (None )
88+
89+ start_time = time .time ()
90+ processes = []
91+ for device_id in device_ids :
92+ p = Process (target = worker , args = (device_id , tune_gemm_results_file_path , task_queue ))
93+ p .start ()
94+ processes .append (p )
95+
96+ for p in processes :
97+ p .join ()
7298
73- for idx , script in enumerate (self .tune_script_list ):
74- print (f"Tune[{ idx } /{ len (self .tune_script_list )} ]:{ script } " )
75- subprocess .run (shlex .split (script ), env = env )
99+ end_time = time .time ()
100+ elapsed_time = end_time - start_time
101+ print (
102+ f"Tune cases Nums: { len (self .tune_script_list )} . Elapsed Time: { elapsed_time :.2f} s" ,
103+ )
76104
77105
78106if __name__ == "__main__" :
79107 parser = argparse .ArgumentParser ()
80- parser .add_argument ("--dump-shape-pathh " , type = str )
108+ parser .add_argument ("--dump-shape-path " , type = str )
81109 parser .add_argument ("--tune-result-path" , type = str )
82- # parser.add_argument("--device-id ", type=str , default="0" )
110+ parser .add_argument ("--num-devices " , type = int , default = 1 )
83111 args = parser .parse_args ()
112+ device_ids = [str (i ) for i in range (args .num_devices )]
84113
85114 tuner = OfflineTuneGemm (args .dump_shape_path )
86- tuner .tune (args .tune_result_path )
115+ tuner .tune (args .tune_result_path , device_ids )
0 commit comments