11#!/usr/bin/env python
22import argparse
33import csv
4+ import pathlib
45from collections import defaultdict
6+ from math import sqrt
57
68from nsys_jax import (
79 align_profiler_data_timestamps ,
810 apply_warmup_heuristics ,
911 ensure_compiled_protos_are_importable ,
1012 load_profiler_data ,
1113)
12- from math import sqrt
1314from prettytable import PrettyTable
14- import pathlib
1515from uncertainties import ufloat # type: ignore
1616
1717
1818def process_communication_data (steady_state ):
19+ """
20+ Process communication data from a steady state, to compute bandwith summaries.
21+
22+ Args:
23+ steady_state: A steady state data frame.
24+
25+ Return:
26+ A tuple of (collective_types, summary_data), where:
27+ collective_types (List[str]): sorted list of collective operation types
28+ summary_data (Dict[int, Dict[str, ufloat]]): Dictionary wiht summaries for bandwith data
29+ """
1930 collective_types = set ()
2031 summary_data = defaultdict (dict )
21- for (collective , message_size ), df in steady_state .communication .groupby (
32+
33+ communication_grouped_by = steady_state .communication .groupby (
2234 ["Collective" , "MessageSize" ]
23- ):
35+ )
36+
37+ for (collective , message_size ), df in communication_grouped_by :
2438 collective_types .add (collective )
2539 # This grouped data frame will have a row for each device that is participating
2640 # in this instance of the collective.
@@ -31,14 +45,43 @@ def process_communication_data(steady_state):
3145 # therefore the last one to join the collective and its bandwidth estimate does
3246 # not contain a wait time component. The .mean() is over the different
3347 # (ProgramId, ProgramExecution, ThunkIndex) values.
34- bandwidth = devices ["BusBandwidthGBPerSec" ].agg ("max" )
48+ bandwidth_of_fastest_device = devices ["BusBandwidthGBPerSec" ].agg ("max" )
49+ mean_bandwidth = bandwidth_of_fastest_device .mean ()
50+ stderr_bandwidth = bandwidth_of_fastest_device .std () / sqrt (
51+ len (bandwidth_of_fastest_device )
52+ )
53+
3554 summary_data [message_size ][collective ] = ufloat (
36- bandwidth . mean (), bandwidth . std () / sqrt ( len ( bandwidth ))
55+ mean_bandwidth , stderr_bandwidth
3756 )
57+
3858 return sorted (collective_types ), summary_data
3959
4060
4161def print_bandwidth_table (collective_types , summary_data ):
62+ """
63+ This function prints a table for summarizing the bandwidth for each collective operation
64+
65+ Args:
66+ collective_types (List[str]): sorted list of collective operation types
67+ summary_data (Dict[int, Dict[str, ufloat]]): Dictionary wiht summaries for bandwith data
68+ """
69+
70+ def format_message_size (message_size ):
71+ """
72+ Function to format the message size
73+ """
74+ return f"{ message_size :<{size_width },} "
75+
76+ def format_bandwidth (data , collective ):
77+ """
78+ Function to format the bandwidth
79+ """
80+ width = collective_widths [collective ]
81+ if collective not in data :
82+ return "-" * width
83+ return f"{ data [collective ]:>{width }S} "
84+
4285 collective_widths = {
4386 collective : max (
4487 len (collective ),
@@ -52,54 +95,62 @@ def print_bandwidth_table(collective_types, summary_data):
5295 }
5396 size_heading = "Size [B]"
5497 size_width = max (len (size_heading ), max (len (f"{ s :,} " ) for s in summary_data .keys ()))
55- print (f"{ '' :<{size_width }} | Bus bandwidth [GB/s]" )
56- print (
57- " | " .join (
58- [f"{ size_heading :<{size_width }} " ]
59- + [f"{ coll :<{collective_widths [coll ]}} " for coll in collective_types ]
60- )
61- )
62-
63- def format_message_size (message_size ):
64- return f"{ message_size :<{size_width },} "
6598
66- def format_bandwidth (data , collective ):
67- width = collective_widths [collective ]
68- if collective not in data :
69- return "-" * width
70- return f"{ data [collective ]:>{width }S} "
99+ header_log = f"{ '' :<{size_width }} | Bus bandwidth [GB/s]"
100+ print (header_log )
101+ log_specs = " | " .join (
102+ [f"{ size_heading :<{size_width }} " ]
103+ + [f"{ coll :<{collective_widths [coll ]}} " for coll in collective_types ]
104+ )
105+ print (log_specs )
71106
72107 for message_size in sorted (summary_data .keys ()):
73108 data = summary_data [message_size ]
74- print (
75- " | " .join (
76- [format_message_size (message_size )]
77- + [
78- format_bandwidth (data , collective )
79- for collective in collective_types
80- ]
81- )
109+ log_row = " | " .join (
110+ [format_message_size (message_size )]
111+ + [format_bandwidth (data , collective ) for collective in collective_types ]
82112 )
113+ print (log_row )
83114
84115
85116def process_hidden_ms_to_total_ms (steady_state ):
117+ """
118+ Function to compute the fraction of communication time that is hidden behind computations.
119+
120+ Args:
121+ steady_state: The steady state data
122+
123+ Returns:
124+ collective_types (Set[str]): set of collective operation types
125+ summary_data (Dict[str, float]): dictionary with mean hidden-to-total milliseconds ratio
126+ """
86127 if steady_state .communication ["ProjDurHiddenMs" ].sum () == 0 :
87128 return None , None
88129
89130 collective_types = set ()
90131 summary_data = defaultdict (dict )
91- for collective , df in steady_state .communication .groupby (["Collective" ]):
132+ grouped_data = steady_state .communication .groupby (["Collective" ])
133+
134+ for collective , df in grouped_data :
92135 collective_types .add (collective )
93- mean_dur_hidden_ms_to_total_ms = (
94- df ["ProjDurHiddenMs" ] / (df ["ProjDurMs" ] + df ["ProjDurHiddenMs" ])
95- ).mean ()
136+ total_ms = df ["ProjDurMs" ] + df ["ProjDurHiddenMs" ]
137+ mean_dur_hidden_ms_to_total_ms = (df ["ProjDurHiddenMs" ] / total_ms ).mean ()
96138 summary_data [collective ] = mean_dur_hidden_ms_to_total_ms
139+
97140 return collective_types , summary_data
98141
99142
100143def print_hidden_ms_to_total_ms_table (
101144 collective_types , summary_data , overall_hidden_ms_to_total_ms
102145):
146+ """
147+ Print the hidden ms to total ms
148+
149+ Args:
150+ collective_types (Set[str]): set of collective operation types
151+ summary_data (Dict[str, float]): mean hidden-to-total milliseconds ratio
152+ overall_ratio (float): overall hidden-to-total milliseconds ratio
153+ """
103154 table = PrettyTable ()
104155 table .field_names = ["Collective" , "Mean HiddenToTotalMs" ]
105156
@@ -108,20 +159,31 @@ def print_hidden_ms_to_total_ms_table(
108159 table .add_row ([collective [0 ], mean_value ])
109160
110161 print (table )
111- print ("Overall HiddenMs to TotalMs:" , overall_hidden_ms_to_total_ms )
162+ if overall_hidden_ms_to_total_ms is not None :
163+ print (f"Overall HiddenMs to TotalMs: { overall_hidden_ms_to_total_ms :.4f} " )
112164
113165
114166def calculate_overall_hidden_ms_to_total_ms (steady_state ):
115- if steady_state .communication ["ProjDurHiddenMs" ].sum () == 0 :
167+ """
168+ Function to calculate the overall hidden milliseconds to total milliseconds
169+
170+ Args:
171+ steady_state: the steady-state data extracted from the profiler
172+
173+ Returns:
174+ overall_hidden_ms_to_total_ms (float): overall hidden milliseconds to total milliseconds ratio
175+ """
176+ total_hidden_ms = steady_state .communication ["ProjDurHiddenMs" ].sum ()
177+ if total_hidden_ms == 0 :
116178 return None
117179
118- overall_hidden_ms_to_total_ms = (
119- steady_state .communication ["ProjDurHiddenMs" ]. sum ()
120- / (
121- steady_state . communication [ "ProjDurMs" ]
122- + steady_state . communication [ "ProjDurHiddenMs" ]
123- ). sum ()
124- )
180+ total_ms = (
181+ steady_state .communication ["ProjDurMs" ]
182+ + steady_state . communication [ "ProjDurHiddenMs" ]
183+ ). sum ()
184+
185+ overall_hidden_ms_to_total_ms = total_hidden_ms / total_ms
186+
125187 return overall_hidden_ms_to_total_ms
126188
127189
@@ -132,6 +194,17 @@ def write_to_csv(
132194 overall_hidden_ms_to_total_ms ,
133195 output_file ,
134196):
197+ """
198+ Function to write the summaries to a csv file
199+
200+ Args:
201+ collective_types (List[str]): list of collective operation types
202+ bandwidth_summary (Dict[int, Dict[str, ufloat]]): bandwidth summary data
203+ hidden_to_total_summary (Dict[str, float]): hidden-to-total milliseconds ratio summary
204+ overall_hidden_ms_to_total_ms (float): overall hidden-to-total milliseconds ratio
205+ output_file (str): output CSV file path
206+
207+ """
135208 with open (output_file , "w" , newline = "" ) as csvfile :
136209 writer = csv .writer (csvfile )
137210
@@ -165,6 +238,9 @@ def write_to_csv(
165238
166239
167240def main ():
241+ """
242+ Main entry point to process the nsys-jax report and generate communication summaries
243+ """
168244 parser = argparse .ArgumentParser (
169245 description = "Summarise communication in an nsys-jax report"
170246 )
@@ -178,22 +254,26 @@ def main():
178254 all_data = load_profiler_data (args .prefix , frames = {"communication" , "compile" })
179255 # Align timestamps
180256 all_data , alignment_metadata = align_profiler_data_timestamps (all_data )
181- # TODO: make this pretty
182- # print(alignment_metadata)
257+ print (f"Alignment metadata: { alignment_metadata } " )
183258 # Partition the profile data into initialisation and steady-state running
184259 _ , steady_state = apply_warmup_heuristics (all_data )
185260
186- assert len (steady_state .communication ), (
187- "Communication summary was requested but no steady-state communication was "
188- "identified."
189- )
261+ if len (steady_state .communication ) == 0 :
262+ print (
263+ "Communication summary was requested but no steady-state communication was identified."
264+ )
265+ return
190266
191267 collective_types , bandwidth_summary = process_communication_data (steady_state )
192268 print_bandwidth_table (collective_types , bandwidth_summary )
193269
194270 hidden_to_total_collective_types , hidden_to_total_summary = (
195271 process_hidden_ms_to_total_ms (steady_state )
196272 )
273+
274+ # initailise overall_hidden_ms_to_total_ms
275+ overall_hidden_ms_to_total_ms = None
276+
197277 if hidden_to_total_summary is not None :
198278 overall_hidden_ms_to_total_ms = calculate_overall_hidden_ms_to_total_ms (
199279 steady_state
0 commit comments