3838)
3939import numpy as np
4040import os
41+ from DGraph .utils .TimingReport import TimingReport
42+ import json
4143
4244
4345class SingleProcessDummyCommunicator (CommunicatorBase ):
@@ -81,6 +83,10 @@ def rank_cuda_device(self):
8183 device = torch .cuda .current_device ()
8284 return device
8385
86+ def barrier (self ):
87+ # No-op for single process
88+ pass
89+
8490
8591def _run_experiment (
8692 dataset ,
@@ -98,7 +104,7 @@ def _run_experiment(
98104 torch .cuda .set_device (local_rank )
99105 device = torch .cuda .current_device ()
100106 model = GCN (
101- in_channels = 100 , hidden_dims = hidden_dims , num_classes = num_classes , comm = comm
107+ in_channels = 128 , hidden_dims = hidden_dims , num_classes = num_classes , comm = comm
102108 )
103109 rank = comm .get_rank ()
104110 model = model .to (device )
@@ -126,7 +132,7 @@ def _run_experiment(
126132 print (f"Rank: { rank } Mapping: { rank_mappings .shape } " )
127133 print (f"Rank: { rank } Node Features: { node_features .shape } " )
128134 print (f"Rank: { rank } Edge Indices: { edge_indices .shape } " )
129- dist .barrier ()
135+ comm .barrier ()
130136 criterion = torch .nn .CrossEntropyLoss ()
131137
132138 train_mask = dataset .graph_obj .get_local_mask ("train" , rank )
@@ -152,14 +158,13 @@ def _run_experiment(
152158 scatter_cache_file = f"{ cache_prefix } _scatter_cache_{ world_size } _{ rank } .pt"
153159 gather_cache_file = f"{ cache_prefix } _gather_cache_{ world_size } _{ rank } .pt"
154160
161+ # if os.path.exists(scatter_cache_file):
162+ # print(f"Rank: {rank} Loading scatter cache from {scatter_cache_file}")
163+ # scatter_cache = torch.load(scatter_cache_file, weights_only=False)
164+ # else:
165+ # print(f"Rank: {rank} Scatter cache not found, generating new cache")
166+ # print(f"Rank: {rank} Cache file: {scatter_cache_file}")
155167
156- if os .path .exists (scatter_cache_file ):
157- print (f"Rank: { rank } Loading scatter cache from { scatter_cache_file } " )
158- scatter_cache = torch .load (scatter_cache_file , weights_only = False )
159- else :
160- print (f"Rank: { rank } Scatter cache not found, generating new cache" )
161- print (f"Rank: { rank } Cache file: { scatter_cache_file } " )
162-
163168 if os .path .exists (gather_cache_file ):
164169 print (f"Rank: { rank } Loading gather cache from { gather_cache_file } " )
165170 gather_cache = torch .load (gather_cache_file , weights_only = False )
@@ -227,7 +232,9 @@ def _run_experiment(
227232 assert rank != rank
228233 assert value .shape [0 ] == scatter_cache .gather_recv_comm_vector
229234 end_time = perf_counter ()
230- print (f"Rank: { rank } Cache Generation Time: { end_time - start_time :.4f} s" )
235+ elapsed_time_in_ms = (end_time - start_time ) * 1000
236+ print (f"Rank: { rank } Cache Generation Time: { elapsed_time_in_ms :.4f} ms" )
237+ TimingReport .add_time ("cache_generation_time" , elapsed_time_in_ms )
231238
232239 with open (f"{ log_prefix } _gather_cache_{ world_size } _{ rank } .pt" , "wb" ) as f :
233240 torch .save (gather_cache , f )
@@ -237,7 +244,7 @@ def _run_experiment(
237244
238245 training_times = []
239246 for i in range (epochs ):
240- dist .barrier ()
247+ comm .barrier ()
241248 torch .cuda .synchronize ()
242249 start_time = torch .cuda .Event (enable_timing = True )
243250 end_time = torch .cuda .Event (enable_timing = True )
@@ -254,7 +261,7 @@ def _run_experiment(
254261 dist_print_ephemeral (f"Epoch { i } \t Loss: { loss .item ()} " , rank )
255262 optimizer .step ()
256263
257- dist .barrier ()
264+ comm .barrier ()
258265 end_time .record (stream )
259266 torch .cuda .synchronize ()
260267 training_times .append (start_time .elapsed_time (end_time ))
@@ -362,6 +369,7 @@ def main(
362369 node_rank_placement_file , weights_only = False
363370 )
364371
372+ TimingReport .init (comm )
365373 safe_create_dir (log_dir , comm .get_rank ())
366374 training_dataset = DistributedOGBWrapper (
367375 f"ogbn-{ dataset } " ,
@@ -377,7 +385,7 @@ def main(
377385 validation_accuracies = np .zeros ((runs , epochs ))
378386 world_size = comm .get_world_size ()
379387
380- dist .barrier ()
388+ comm .barrier ()
381389 print (f"Running experiment with { world_size } processes on dataset { dataset } " )
382390 print (f"Using cache: { use_cache } " )
383391
@@ -397,6 +405,11 @@ def main(
397405 validation_trajectores [i ] = val_traj
398406 validation_accuracies [i ] = val_accuracy
399407
408+ write_experiment_log (
409+ json .dumps (TimingReport ._timers ),
410+ f"{ log_dir } /timing_report_world_size_{ world_size } _cache_{ use_cache } .json" ,
411+ comm .get_rank (),
412+ )
400413 visualize_trajectories (
401414 training_trajectores ,
402415 "Training Loss" ,
0 commit comments