11#!/usr/bin/env python
22import argparse
3+ import csv
34from collections import defaultdict
5+
46from nsys_jax import (
57 align_profiler_data_timestamps ,
68 apply_warmup_heuristics ,
79 ensure_compiled_protos_are_importable ,
810 load_profiler_data ,
911)
1012from math import sqrt
13+ from prettytable import PrettyTable
1114import pathlib
1215from uncertainties import ufloat # type: ignore
1316
1417
15- def main ():
16- parser = argparse .ArgumentParser (
17- description = "Summarise communication in an nsys-jax report"
18- )
19- parser .add_argument ("prefix" , type = pathlib .Path )
20- args = parser .parse_args ()
21- # Make sure that the .proto files under protos/ have been compiled to .py, and
22- # that those generated .py files are importable.
23- ensure_compiled_protos_are_importable (prefix = args .prefix )
24- # Load the profiler data; the compilation part is needed for the warmup heuristics
25- all_data = load_profiler_data (args .prefix , frames = {"communication" , "compile" })
26- # Align timestamps
27- all_data , alignment_metadata = align_profiler_data_timestamps (all_data )
28- # TODO: make this pretty
29- # print(alignment_metadata)
30- # Partition the profile data into initialisation and steady-state running
31- _ , steady_state = apply_warmup_heuristics (all_data )
32- assert len (steady_state .communication ), (
33- "Communication summary was requested but no steady-state communication was "
34- "identified."
35- )
18+ def process_communication_data (steady_state ):
3619 collective_types = set ()
3720 summary_data = defaultdict (dict )
3821 for (collective , message_size ), df in steady_state .communication .groupby (
@@ -52,7 +35,10 @@ def main():
5235 summary_data [message_size ][collective ] = ufloat (
5336 bandwidth .mean (), bandwidth .std () / sqrt (len (bandwidth ))
5437 )
55- collective_types = sorted (collective_types )
38+ return sorted (collective_types ), summary_data
39+
40+
41+ def print_bandwidth_table (collective_types , summary_data ):
5642 collective_widths = {
5743 collective : max (
5844 len (collective ),
@@ -96,5 +82,137 @@ def format_bandwidth(data, collective):
9682 )
9783
9884
85+ def process_hidden_ms_to_total_ms (steady_state ):
86+ if steady_state .communication ["ProjDurHiddenMs" ].sum () == 0 :
87+ return None , None
88+
89+ collective_types = set ()
90+ summary_data = defaultdict (dict )
91+ for collective , df in steady_state .communication .groupby (["Collective" ]):
92+ collective_types .add (collective )
93+ mean_dur_hidden_ms_to_total_ms = (
94+ df ["ProjDurHiddenMs" ] / (df ["ProjDurMs" ] + df ["ProjDurHiddenMs" ])
95+ ).mean ()
96+ summary_data [collective ] = mean_dur_hidden_ms_to_total_ms
97+ return collective_types , summary_data
98+
99+
100+ def print_hidden_ms_to_total_ms_table (
101+ collective_types , summary_data , overall_hidden_ms_to_total_ms
102+ ):
103+ table = PrettyTable ()
104+ table .field_names = ["Collective" , "Mean HiddenToTotalMs" ]
105+
106+ for collective in collective_types :
107+ mean_value = summary_data [collective ]
108+ table .add_row ([collective [0 ], mean_value ])
109+
110+ print (table )
111+ print ("Overall HiddenMs to TotalMs:" , overall_hidden_ms_to_total_ms )
112+
113+
114+ def calculate_overall_hidden_ms_to_total_ms (steady_state ):
115+ if steady_state .communication ["ProjDurHiddenMs" ].sum () == 0 :
116+ return None
117+
118+ overall_hidden_ms_to_total_ms = (
119+ steady_state .communication ["ProjDurHiddenMs" ].sum ()
120+ / (
121+ steady_state .communication ["ProjDurMs" ]
122+ + steady_state .communication ["ProjDurHiddenMs" ]
123+ ).sum ()
124+ )
125+ return overall_hidden_ms_to_total_ms
126+
127+
128+ def write_to_csv (
129+ collective_types ,
130+ bandwidth_summary ,
131+ hidden_to_total_summary ,
132+ overall_hidden_ms_to_total_ms ,
133+ output_file ,
134+ ):
135+ with open (output_file , "w" , newline = "" ) as csvfile :
136+ writer = csv .writer (csvfile )
137+
138+ # Write bandwidth table
139+ writer .writerow (["Bandwidth Table" ])
140+ writer .writerow (["Size [B]" ] + list (collective_types ))
141+ for message_size in sorted (bandwidth_summary .keys ()):
142+ row = [message_size ]
143+ for collective in collective_types :
144+ if collective in bandwidth_summary [message_size ]:
145+ row .append (f"{ bandwidth_summary [message_size ][collective ]:S} " )
146+ else :
147+ row .append ("-" )
148+ writer .writerow (row )
149+
150+ writer .writerow ([]) # Empty row for separation
151+
152+ # Write hidden to total table if data is available
153+ if hidden_to_total_summary is not None :
154+ writer .writerow (["HiddenMs to TotalMs Table" ])
155+ writer .writerow (["Collective" , "Mean HiddenToTotalMs" ])
156+ for collective in hidden_to_total_summary :
157+ writer .writerow ([collective [0 ], hidden_to_total_summary [collective ]])
158+
159+ writer .writerow ([]) # Empty row for separation
160+
161+ if overall_hidden_ms_to_total_ms is not None :
162+ writer .writerow (
163+ ["Overall HiddenMs to TotalMs" , overall_hidden_ms_to_total_ms ]
164+ )
165+
166+
167+ def main ():
168+ parser = argparse .ArgumentParser (
169+ description = "Summarise communication in an nsys-jax report"
170+ )
171+ parser .add_argument ("prefix" , type = pathlib .Path )
172+ args = parser .parse_args ()
173+
174+ # Make sure that the .proto files under protos/ have been compiled to .py, and
175+ # that those generated .py files are importable.
176+ ensure_compiled_protos_are_importable (prefix = args .prefix )
177+ # Load the profiler data; the compilation part is needed for the warmup heuristics
178+ all_data = load_profiler_data (args .prefix , frames = {"communication" , "compile" })
179+ # Align timestamps
180+ all_data , alignment_metadata = align_profiler_data_timestamps (all_data )
181+ # TODO: make this pretty
182+ # print(alignment_metadata)
183+ # Partition the profile data into initialisation and steady-state running
184+ _ , steady_state = apply_warmup_heuristics (all_data )
185+
186+ assert len (steady_state .communication ), (
187+ "Communication summary was requested but no steady-state communication was "
188+ "identified."
189+ )
190+
191+ collective_types , bandwidth_summary = process_communication_data (steady_state )
192+ print_bandwidth_table (collective_types , bandwidth_summary )
193+
194+ hidden_to_total_collective_types , hidden_to_total_summary = (
195+ process_hidden_ms_to_total_ms (steady_state )
196+ )
197+ if hidden_to_total_summary is not None :
198+ overall_hidden_ms_to_total_ms = calculate_overall_hidden_ms_to_total_ms (
199+ steady_state
200+ )
201+ print_hidden_ms_to_total_ms_table (
202+ hidden_to_total_collective_types ,
203+ hidden_to_total_summary ,
204+ overall_hidden_ms_to_total_ms ,
205+ )
206+
207+ # Write all tables to a single CSV file
208+ write_to_csv (
209+ collective_types ,
210+ bandwidth_summary ,
211+ hidden_to_total_summary ,
212+ overall_hidden_ms_to_total_ms ,
213+ "communication_summary.csv" ,
214+ )
215+
216+
99217if __name__ == "__main__" :
100218 main ()
0 commit comments