66 load_profiler_data ,
77 xla_module_metadata ,
88)
9+ from nsys_jax .protobuf import HloProto , HloProtoSet
910import pathlib
1011
1112
12- def write_pbtxt (outdir : pathlib .Path , series_ms , hlo_module ):
13- mod_proto = hlo_module .proto ().hlo_module
14- fingerprint = mod_proto .frontend_attributes .map ["fingerprint_before_lhs" ]
13+ def get_scheduling_name (module : HloProto , name : str ) -> str :
14+ _ , inst = module .find_instruction (name )
15+ return inst .proto ().metadata .scheduling_name
16+
17+
18+ def write_pbtxt (outdir : pathlib .Path , series_ms , hlo_module_set : HloProtoSet ):
19+ fingerprint = hlo_module_set .unique_result (
20+ lambda mod : mod .proto ().hlo_module .frontend_attributes .map [
21+ "fingerprint_before_lhs"
22+ ]
23+ )
1524 outdir .mkdir (exist_ok = True )
1625 fp_fname = f"{ fingerprint } .pbtxt"
1726 null_names = 0
1827 with open (outdir / fp_fname , "w" ) as ofile :
1928 for name , cost_ms in series_ms .items ():
20- comp , inst = hlo_module .find_instruction (name )
21- scheduling_name = inst .proto ().metadata .scheduling_name
29+ scheduling_name = hlo_module_set .unique_result (
30+ lambda mod : get_scheduling_name (mod , name )
31+ )
2232 null_names += len (scheduling_name ) == 0
2333 ofile .write (
2434 f'costs {{ name: "{ scheduling_name } " cost_us: { cost_ms * 1000 :.1f} }}\n '
@@ -61,7 +71,7 @@ def main():
6171 for row in module_ranking .itertuples ():
6272 print (f"Processing module { row .Name } ({ row .Index } )" )
6373 try :
64- hlo_module = xla_module_metadata (row .Index , prefix = args .prefix )
74+ hlo_set = xla_module_metadata (row .Index , policy = "all" , prefix = args .prefix )
6575 except Exception as e :
6676 print (f"Skipping due to: { e } " )
6777 continue
@@ -73,7 +83,7 @@ def main():
7383 write_pbtxt (
7484 pathlib .Path ("./maxcomm_mincompute" ),
7585 min_compute_max_comm (thunk_df .groupby ("Name" )),
76- hlo_module ,
86+ hlo_set ,
7787 )
7888
7989
0 commit comments