Skip to content

Commit d4c7306

Browse files
committed
Add barrier around dataset processor to prevent a race condition
1 parent cce5c0e commit d4c7306

File tree

3 files changed

+35
-3
lines changed

3 files changed

+35
-3
lines changed

DGraph/distributed/nccl/NCCLBackendEngine.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,11 @@ def __init__(self, ranks_per_graph=-1, *args, **kwargs):
512512
if not NCCLBackendEngine._is_initialized:
513513
self.init_process_group(ranks_per_graph)
514514

515+
def barrier(self) -> None:
    """Synchronize all ranks in the distributed process group.

    Blocks the calling process until every rank has reached this point.

    Raises:
        RuntimeError: if the distributed backend has not been initialized.
    """
    if dist.is_initialized():
        dist.barrier()
    else:
        raise RuntimeError("NCCL backend engine is not initialized")
519+
515520
def init_process_group(self, ranks_per_graph=-1, *args, **kwargs):
516521
if not dist.is_initialized():
517522
dist.init_process_group(backend="nccl", *args, **kwargs)

experiments/OGB/GCN.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import torch
1515
import torch.nn as nn
1616
import torch.distributed as dist
17+
from DGraph.utils.TimingReport import TimingReport
1718

1819

1920
class ConvLayer(nn.Module):
@@ -54,24 +55,41 @@ def forward(
5455
num_local_nodes = node_features.size(1)
5556
_src_indices = edge_index[:, 0, :]
5657
_dst_indices = edge_index[:, 1, :]
58+
TimingReport.start("pre-processing")
5759
_src_rank_mappings = torch.cat(
5860
[rank_mapping[0].unsqueeze(0), rank_mapping[0].unsqueeze(0)], dim=0
5961
)
6062
_dst_rank_mappings = torch.cat(
6163
[rank_mapping[0].unsqueeze(0), rank_mapping[1].unsqueeze(0)], dim=0
6264
)
65+
TimingReport.stop("pre-processing")
66+
TimingReport.start("Gather_1")
6367
x = self.comm.gather(
6468
node_features, _dst_indices, _dst_rank_mappings, cache=gather_cache
6569
)
70+
TimingReport.stop("Gather_1")
71+
TimingReport.start("Conv_1")
6672
x = self.conv1(x)
73+
TimingReport.stop("Conv_1")
74+
TimingReport.start("Scatter_1")
6775
x = self.comm.scatter(
6876
x, _src_indices, _src_rank_mappings, num_local_nodes, cache=scatter_cache
6977
)
78+
TimingReport.stop("Scatter_1")
79+
TimingReport.start("Gather_2")
7080
x = self.comm.gather(x, _dst_indices, _dst_rank_mappings, cache=gather_cache)
81+
TimingReport.stop("Gather_2")
82+
TimingReport.start("Conv_2")
7183
x = self.conv2(x)
84+
TimingReport.stop("Conv_2")
85+
TimingReport.start("Scatter_2")
7286
x = self.comm.scatter(
7387
x, _src_indices, _src_rank_mappings, num_local_nodes, cache=scatter_cache
7488
)
89+
TimingReport.stop("Scatter_2")
90+
TimingReport.start("Final_FC")
7591
x = self.fc(x)
92+
TimingReport.stop("Final_FC")
93+
7694
# x = self.softmax(x)
7795
return x

experiments/OGB/main.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
)
3939
import numpy as np
4040
import os
41+
from DGraph.utils.TimingReport import TimingReport
42+
import json
4143

4244

4345
class SingleProcessDummyCommunicator(CommunicatorBase):
@@ -131,7 +133,6 @@ def _run_experiment(
131133
print(f"Rank: {rank} Mapping: {rank_mappings.shape}")
132134
print(f"Rank: {rank} Node Features: {node_features.shape}")
133135
print(f"Rank: {rank} Edge Indices: {edge_indices.shape}")
134-
135136
comm.barrier()
136137
criterion = torch.nn.CrossEntropyLoss()
137138

@@ -229,7 +230,9 @@ def _run_experiment(
229230
assert rank != rank
230231
assert value.shape[0] == scatter_cache.gather_recv_comm_vector
231232
end_time = perf_counter()
232-
print(f"Rank: {rank} Cache Generation Time: {end_time - start_time:.4f} s")
233+
elapsed_time_in_ms = (end_time - start_time) * 1000
234+
print(f"Rank: {rank} Cache Generation Time: {elapsed_time_in_ms:.4f} ms")
235+
TimingReport.add_time("cache_generation_time", elapsed_time_in_ms)
233236

234237
# with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f:
235238
# torch.save(gather_cache, f)
@@ -366,6 +369,7 @@ def main(
366369
node_rank_placement_file, weights_only=False
367370
)
368371

372+
TimingReport.init(comm)
369373
safe_create_dir(log_dir, comm.get_rank())
370374
training_dataset = DistributedOGBWrapper(
371375
f"ogbn-{dataset}",
@@ -381,7 +385,7 @@ def main(
381385
validation_accuracies = np.zeros((runs, epochs))
382386
world_size = comm.get_world_size()
383387

384-
dist.barrier()
388+
comm.barrier()
385389
print(f"Running experiment with {world_size} processes on dataset {dataset}")
386390
print(f"Using cache: {use_cache}")
387391

@@ -402,6 +406,11 @@ def main(
402406
validation_trajectores[i] = val_traj
403407
validation_accuracies[i] = val_accuracy
404408

409+
write_experiment_log(
410+
json.dumps(TimingReport._timers),
411+
f"{log_dir}/timing_report_world_size_{world_size}_cache_{use_cache}.json",
412+
comm.get_rank(),
413+
)
405414
visualize_trajectories(
406415
training_trajectores,
407416
"Training Loss",

0 commit comments

Comments (0)