
Commit b3b4e9b

Add barrier around dataset processor for race condition
1 parent 29a34ed commit b3b4e9b

File tree: 5 files changed (+66 −16 lines)

DGraph/CommunicatorBase.py

Lines changed: 3 additions & 0 deletions

@@ -26,3 +26,6 @@ def get_rank(self) -> int:

     def get_world_size(self) -> int:
         raise NotImplementedError
+
+    def barrier(self):
+        raise NotImplementedError
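For context, barrier() joins get_rank() and get_world_size() in the abstract communicator interface, so every concrete backend is expected to override it. Below is a minimal sketch of what a CPU-only subclass built on torch.distributed's gloo backend could look like; the class name, import path, and constructor are illustrative and not part of this commit:

```python
import torch.distributed as dist

from DGraph.CommunicatorBase import CommunicatorBase  # module path assumed from the file above


class GlooCommunicator(CommunicatorBase):
    """Illustrative backend; only the methods touched by this commit are shown."""

    def __init__(self):
        if not dist.is_initialized():
            dist.init_process_group(backend="gloo")

    def get_rank(self) -> int:
        return dist.get_rank()

    def get_world_size(self) -> int:
        return dist.get_world_size()

    def barrier(self):
        # Block until every rank in the process group reaches this point.
        dist.barrier()
```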

DGraph/data/ogbn_datasets.py

Lines changed: 14 additions & 3 deletions

@@ -174,9 +174,20 @@ def __init__(
         self._rank = self.comm_object.get_rank()
         self._world_size = self.comm_object.get_world_size()

-        self.dataset = NodePropPredDataset(
-            name=dname,
-        )
+        comm_object.barrier()
+        if comm_object.get_rank() == 0:
+            self.dataset = NodePropPredDataset(
+                name=dname,
+            )
+        # Block until the dataset is loaded on rank 0
+        comm_object.barrier()
+        # Load the dataset on all other ranks; this reuses the data
+        # already downloaded and processed by rank 0
+        if comm_object.get_rank() != 0:
+            self.dataset = NodePropPredDataset(
+                name=dname,
+            )
+
         graph_data, labels = self.dataset[0]

         self.split_idx = self.dataset.get_idx_split()
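The change follows the usual "download on rank 0 only" pattern: all ranks synchronize, rank 0 downloads and preprocesses the OGB dataset, and the other ranks wait at a second barrier before constructing the same dataset object from the files rank 0 has already written. Without the barriers, several ranks could race to download and preprocess into the same directory. A standalone sketch of the same pattern using torch.distributed directly, assuming a shared filesystem and an already-initialized process group:

```python
import torch.distributed as dist
from ogb.nodeproppred import NodePropPredDataset


def load_dataset_once(name: str = "ogbn-products") -> NodePropPredDataset:
    # Only rank 0 downloads and preprocesses the raw files.
    if dist.get_rank() == 0:
        dataset = NodePropPredDataset(name=name)
    # Everyone blocks until rank 0 has finished writing the processed data.
    dist.barrier()
    if dist.get_rank() != 0:
        # Re-reads the processed files left by rank 0; no second download happens.
        dataset = NodePropPredDataset(name=name)
    return dataset
```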

DGraph/distributed/nccl/NCCLBackendEngine.py

Lines changed: 5 additions & 0 deletions

@@ -512,6 +512,11 @@ def __init__(self, ranks_per_graph=-1, *args, **kwargs):
         if not NCCLBackendEngine._is_initialized:
             self.init_process_group(ranks_per_graph)

+    def barrier(self) -> None:
+        if not dist.is_initialized():
+            raise RuntimeError("NCCL backend engine is not initialized")
+        dist.barrier()
+
     def init_process_group(self, ranks_per_graph=-1, *args, **kwargs):
         if not dist.is_initialized():
             dist.init_process_group(backend="nccl", *args, **kwargs)
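One caveat when exercising this path: with the NCCL backend, dist.barrier() executes as a collective on a CUDA device, so each process should have selected its GPU before calling it (the experiment scripts do this with torch.cuda.set_device). A hedged usage sketch, assuming the job is launched with torchrun and that the module path mirrors the file path above:

```python
import os

import torch
import torch.distributed as dist

from DGraph.distributed.nccl.NCCLBackendEngine import NCCLBackendEngine

local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)           # NCCL collectives run on this device

engine = NCCLBackendEngine()  # initializes the nccl process group if needed
engine.barrier()              # every rank blocks here until all ranks arrive

# The device can also be pinned explicitly on the underlying collective:
dist.barrier(device_ids=[local_rank])
```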

experiments/OGB/GCN.py

Lines changed: 18 additions & 0 deletions

@@ -14,6 +14,7 @@
 import torch
 import torch.nn as nn
 import torch.distributed as dist
+from DGraph.utils.TimingReport import TimingReport


 class ConvLayer(nn.Module):
@@ -54,24 +55,41 @@ def forward(
         num_local_nodes = node_features.size(1)
         _src_indices = edge_index[:, 0, :]
         _dst_indices = edge_index[:, 1, :]
+        TimingReport.start("pre-processing")
         _src_rank_mappings = torch.cat(
             [rank_mapping[0].unsqueeze(0), rank_mapping[0].unsqueeze(0)], dim=0
         )
         _dst_rank_mappings = torch.cat(
             [rank_mapping[0].unsqueeze(0), rank_mapping[1].unsqueeze(0)], dim=0
         )
+        TimingReport.stop("pre-processing")
+        TimingReport.start("Gather_1")
         x = self.comm.gather(
             node_features, _dst_indices, _dst_rank_mappings, cache=gather_cache
         )
+        TimingReport.stop("Gather_1")
+        TimingReport.start("Conv_1")
         x = self.conv1(x)
+        TimingReport.stop("Conv_1")
+        TimingReport.start("Scatter_1")
         x = self.comm.scatter(
             x, _src_indices, _src_rank_mappings, num_local_nodes, cache=scatter_cache
         )
+        TimingReport.stop("Scatter_1")
+        TimingReport.start("Gather_2")
         x = self.comm.gather(x, _dst_indices, _dst_rank_mappings, cache=gather_cache)
+        TimingReport.stop("Gather_2")
+        TimingReport.start("Conv_2")
         x = self.conv2(x)
+        TimingReport.stop("Conv_2")
+        TimingReport.start("Scatter_2")
         x = self.comm.scatter(
             x, _src_indices, _src_rank_mappings, num_local_nodes, cache=scatter_cache
         )
+        TimingReport.stop("Scatter_2")
+        TimingReport.start("Final_FC")
         x = self.fc(x)
+        TimingReport.stop("Final_FC")
+
         # x = self.softmax(x)
         return x
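DGraph.utils.TimingReport itself is not part of this commit, so only its interface can be inferred from the call sites: init(comm), start(name), stop(name), add_time(name, ms), and a _timers dict that main.py serializes to JSON. The sketch below is one plausible implementation of that interface, written purely for illustration; the real utility may differ:

```python
from time import perf_counter


class TimingReport:
    """Process-wide named timers; durations are accumulated in milliseconds."""

    _timers: dict = {}  # name -> list of recorded durations (ms)
    _starts: dict = {}  # name -> perf_counter() value of the currently open interval
    _comm = None

    @classmethod
    def init(cls, comm):
        # Keep the communicator around, e.g. to tag reports with the rank later.
        cls._comm = comm
        cls._timers = {}

    @classmethod
    def start(cls, name: str) -> None:
        cls._starts[name] = perf_counter()

    @classmethod
    def stop(cls, name: str) -> None:
        elapsed_ms = (perf_counter() - cls._starts.pop(name)) * 1000
        cls.add_time(name, elapsed_ms)

    @classmethod
    def add_time(cls, name: str, elapsed_ms: float) -> None:
        cls._timers.setdefault(name, []).append(elapsed_ms)
```

Note that start/stop in this sketch measure host-side wall-clock time around asynchronous GPU work; the per-epoch GPU timing in the training loop still relies on CUDA events, as shown in main.py below.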

experiments/OGB/main.py

Lines changed: 26 additions & 13 deletions

@@ -38,6 +38,8 @@
 )
 import numpy as np
 import os
+from DGraph.utils.TimingReport import TimingReport
+import json


 class SingleProcessDummyCommunicator(CommunicatorBase):
@@ -81,6 +83,10 @@ def rank_cuda_device(self):
         device = torch.cuda.current_device()
         return device

+    def barrier(self):
+        # No-op for single process
+        pass
+

 def _run_experiment(
     dataset,
@@ -98,7 +104,7 @@ def _run_experiment(
     torch.cuda.set_device(local_rank)
     device = torch.cuda.current_device()
     model = GCN(
-        in_channels=100, hidden_dims=hidden_dims, num_classes=num_classes, comm=comm
+        in_channels=128, hidden_dims=hidden_dims, num_classes=num_classes, comm=comm
     )
     rank = comm.get_rank()
     model = model.to(device)
@@ -126,7 +132,7 @@ def _run_experiment(
     print(f"Rank: {rank} Mapping: {rank_mappings.shape}")
     print(f"Rank: {rank} Node Features: {node_features.shape}")
    print(f"Rank: {rank} Edge Indices: {edge_indices.shape}")
-    dist.barrier()
+    comm.barrier()
     criterion = torch.nn.CrossEntropyLoss()

     train_mask = dataset.graph_obj.get_local_mask("train", rank)
@@ -152,14 +158,13 @@ def _run_experiment(
     scatter_cache_file = f"{cache_prefix}_scatter_cache_{world_size}_{rank}.pt"
     gather_cache_file = f"{cache_prefix}_gather_cache_{world_size}_{rank}.pt"

+    # if os.path.exists(scatter_cache_file):
+    #     print(f"Rank: {rank} Loading scatter cache from {scatter_cache_file}")
+    #     scatter_cache = torch.load(scatter_cache_file, weights_only=False)
+    # else:
+    #     print(f"Rank: {rank} Scatter cache not found, generating new cache")
+    #     print(f"Rank: {rank} Cache file: {scatter_cache_file}")

-    if os.path.exists(scatter_cache_file):
-        print(f"Rank: {rank} Loading scatter cache from {scatter_cache_file}")
-        scatter_cache = torch.load(scatter_cache_file, weights_only=False)
-    else:
-        print(f"Rank: {rank} Scatter cache not found, generating new cache")
-        print(f"Rank: {rank} Cache file: {scatter_cache_file}")
-
     if os.path.exists(gather_cache_file):
         print(f"Rank: {rank} Loading gather cache from {gather_cache_file}")
         gather_cache = torch.load(gather_cache_file, weights_only=False)
@@ -227,7 +232,9 @@ def _run_experiment(
             assert rank != rank
             assert value.shape[0] == scatter_cache.gather_recv_comm_vector
         end_time = perf_counter()
-        print(f"Rank: {rank} Cache Generation Time: {end_time - start_time:.4f} s")
+        elapsed_time_in_ms = (end_time - start_time) * 1000
+        print(f"Rank: {rank} Cache Generation Time: {elapsed_time_in_ms:.4f} ms")
+        TimingReport.add_time("cache_generation_time", elapsed_time_in_ms)

         with open(f"{log_prefix}_gather_cache_{world_size}_{rank}.pt", "wb") as f:
             torch.save(gather_cache, f)
@@ -237,7 +244,7 @@ def _run_experiment(

     training_times = []
     for i in range(epochs):
-        dist.barrier()
+        comm.barrier()
         torch.cuda.synchronize()
         start_time = torch.cuda.Event(enable_timing=True)
         end_time = torch.cuda.Event(enable_timing=True)
@@ -254,7 +261,7 @@ def _run_experiment(
         dist_print_ephemeral(f"Epoch {i} \t Loss: {loss.item()}", rank)
         optimizer.step()

-        dist.barrier()
+        comm.barrier()
         end_time.record(stream)
         torch.cuda.synchronize()
         training_times.append(start_time.elapsed_time(end_time))
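The epoch loop that these two hunks touch times each iteration with CUDA events; switching from dist.barrier() to comm.barrier() keeps that pattern backend-agnostic. The sequence matters: ranks align at the barrier and the host synchronizes with the GPU before the start event is recorded, and again before elapsed_time is read, so a slow rank does not inflate another rank's measurement. A stripped-down sketch of the pattern (train_step is a placeholder for the forward/backward/optimizer step):

```python
import torch


def time_epoch_ms(comm, train_step) -> float:
    """Return the measured duration of one epoch in milliseconds."""
    comm.barrier()               # align all ranks before the timer starts
    torch.cuda.synchronize()     # drain work already queued on the GPU
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    stream = torch.cuda.current_stream()
    start.record(stream)
    train_step()                 # one epoch of forward/backward/step
    comm.barrier()               # wait for the slowest rank
    end.record(stream)
    torch.cuda.synchronize()     # ensure the end event has actually completed
    return start.elapsed_time(end)  # milliseconds between the two events
```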
@@ -362,6 +369,7 @@ def main(
         node_rank_placement_file, weights_only=False
     )

+    TimingReport.init(comm)
     safe_create_dir(log_dir, comm.get_rank())
     training_dataset = DistributedOGBWrapper(
         f"ogbn-{dataset}",
@@ -377,7 +385,7 @@ def main(
     validation_accuracies = np.zeros((runs, epochs))
     world_size = comm.get_world_size()

-    dist.barrier()
+    comm.barrier()
     print(f"Running experiment with {world_size} processes on dataset {dataset}")
     print(f"Using cache: {use_cache}")

@@ -397,6 +405,11 @@ def main(
         validation_trajectores[i] = val_traj
         validation_accuracies[i] = val_accuracy

+    write_experiment_log(
+        json.dumps(TimingReport._timers),
+        f"{log_dir}/timing_report_world_size_{world_size}_cache_{use_cache}.json",
+        comm.get_rank(),
+    )
     visualize_trajectories(
         training_trajectores,
         "Training Loss",
