Commit ef56349

Commit for all files used for SC submission
1 parent 6642eac commit ef56349

File tree

8 files changed: +949 -169 lines
Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@

import torch.nn as nn
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn import NNConv


class LSC_Trainer(nn.Module):
    def __init__(self, num_nodes):
        super(LSC_Trainer, self).__init__()
        self.bond_encoder = BondEncoder(16)   # embeds the 3 OGB bond features into 16 dims
        self.atom_encoder = AtomEncoder(64)   # embeds the 9 OGB atom features into 64 dims

        # Edge network for NNConv: maps each 16-dim bond embedding to a
        # 64 x 32 weight matrix (in_channels * out_channels).
        self._graph_nn = nn.Sequential(nn.Linear(16, 64, bias=False),
                                       nn.ReLU(),
                                       nn.Linear(64, 32, bias=False),
                                       nn.ReLU(),
                                       nn.Linear(32, 64 * 32, bias=False))

        self.graph_conv = NNConv(64, 32, self._graph_nn)
        self.num_nodes = num_nodes

        # Dense regression head over the flattened, zero-padded node features.
        self._nn = nn.Sequential(nn.Linear(num_nodes * 32, 256),
                                 nn.ReLU(),
                                 nn.Linear(256, 128),
                                 nn.ReLU(),
                                 nn.Linear(128, 32),
                                 nn.ReLU(),
                                 nn.Linear(32, 8),
                                 nn.ReLU(),
                                 nn.Linear(8, 1))

    def flatten(self, x):
        return x.view(x.size(0), -1)

    def forward(self, data):
        node_features = data.x
        edge_features = data.edge_attr
        edge_index = data.edge_index
        batch = data.batch

        encoded_atoms = self.atom_encoder(node_features)
        encoded_bonds = self.bond_encoder(edge_features)

        # Edge-conditioned convolution: each node aggregates neighbor
        # features transformed by the per-edge weights from self._graph_nn.
        updated_features = self.graph_conv(encoded_atoms, edge_index, encoded_bonds)

        # Pad every graph to a fixed num_nodes so the dense head sees a
        # constant-size input, then flatten to [batch_size, num_nodes * 32].
        updated_features = self.flatten(
            to_dense_batch(updated_features, batch, max_num_nodes=self.num_nodes)[0])

        out = self._nn(updated_features)

        return out
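
Given the `from SYNTH_Trainer import LSC_Trainer` line in the trainer script below, this new file is presumably SYNTH_Trainer.py. A quick way to sanity-check the model's plumbing is a single forward pass on one random graph shaped like the synthetic data; the snippet below is a hypothetical smoke test, not part of the commit, and assumes ogb and torch_geometric are installed:

import torch
from torch_geometric.data import Data, Batch
from torch_geometric.utils import erdos_renyi_graph

n = 16
edge_index = erdos_renyi_graph(n, 0.3)
# All-zero index features are valid inputs for the OGB atom/bond encoders.
toy = Data(x=torch.zeros(n, 9, dtype=torch.long),
           edge_index=edge_index,
           edge_attr=torch.zeros(edge_index.shape[1], 3, dtype=torch.long))

model = LSC_Trainer(num_nodes=n)
out = model(Batch.from_data_list([toy]))
print(out.shape)  # expected: torch.Size([1, 1]), one regression target per graph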
Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@

from torch_geometric.utils import erdos_renyi_graph
from torch_geometric.data import Data
from tqdm import tqdm
import torch
import pickle
import multiprocessing as mp
import numpy as np

node_sizes = [64, 128]
p_vals = np.logspace(-2.5, -0.4, num=10)

_f_name = "/p/vast1/zaman2/synth_data/{}_{}_Pytorch.pickle"

_num_samples = 10000


def make_dataset(_n, _p):
    """Pickle _num_samples Erdos-Renyi graphs with _n nodes and edge
    probability _p, carrying random OGB-style atom/bond features and a
    random regression target."""
    _dataset = []
    max_edges = 0
    edge_spread = []
    count = 0
    while count < _num_samples:

        edge_indices = erdos_renyi_graph(_n, _p)

        # Keep only graphs with at least two directed edges.
        if edge_indices.shape[1] > 1:
            # torch.randint(1, ...) samples from [0, 1), so every feature
            # index is 0, which the OGB encoders still accept.
            node_features = torch.randint(1, (_n, 9), dtype=torch.int)
            edge_features = torch.randint(1, (edge_indices.shape[1], 3), dtype=torch.int)

            target = torch.rand(1, 1)
            data = Data(x=node_features,
                        edge_index=edge_indices,
                        edge_attr=edge_features,
                        y=target)
            _dataset.append(data)
            edge_spread.append(edge_indices.shape)
            if edge_indices.shape[1] > max_edges:
                max_edges = edge_indices.shape[1]
            count += 1

    # The max edge count is baked into the file name, so the trainer has to
    # be pointed at the matching pickle.
    with open(_f_name.format(_n, max_edges), 'wb') as f:
        pickle.dump(_dataset, f)
    print(max_edges)

    edge_spread = np.array(edge_spread)

    np.save(f'{_n}_{max_edges}.npy', edge_spread)
    return max_edges


for _n in node_sizes:
    for _p in p_vals:
        make_dataset(_n, _p)
        print(_n, _p)
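
For orientation: in an Erdős–Rényi graph each of the n(n-1)/2 node pairs is connected with probability p, and erdos_renyi_graph stores both directions of every undirected edge, so edge_index.shape[1] is roughly p * n * (n - 1). A small sketch, not in the commit, that tabulates this for the sweep above:

import numpy as np

node_sizes = [64, 128]
p_vals = np.logspace(-2.5, -0.4, num=10)

for n in node_sizes:
    for p in p_vals:
        # Expected directed edge count: p * n * (n - 1).
        print(f"n={n:4d}  p={p:.4f}  expected edges ~ {p * n * (n - 1):9.1f}")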
Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@

import numpy as np
import glob
import pickle

_files = glob.glob("/p/vast1/zaman2/synth_data/*.pickle")

# Maps num_nodes -> {num_edges -> list of per-graph edge counts}.
_vals = {}

for _file in _files:
    # File names look like <num_nodes>_<num_edges>_Pytorch.pickle.
    _temp = _file.split("/")[-1].split(".")[0].split("_")
    _num_nodes = _temp[0]
    _num_edges = _temp[1]
    print(_num_nodes, _num_edges)

    with open(_file, 'rb') as f:
        _data = pickle.load(f)

    edge_sizes = []
    for obj in _data:
        edge_sizes.append(obj.num_edges)

    # Accumulate into the per-node-count dict rather than rebuilding it for
    # every file, so files sharing a node count are all retained.
    _vals.setdefault(_num_nodes, {})[_num_edges] = edge_sizes

with open("_gen_stats.pickle", 'wb') as f:
    pickle.dump(_vals, f)
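
A hypothetical consumer of _gen_stats.pickle, assuming the nested num_nodes -> num_edges -> edge-count-list layout built above; not part of the commit:

import pickle
import numpy as np

with open("_gen_stats.pickle", 'rb') as f:
    stats = pickle.load(f)

for num_nodes, edge_dic in stats.items():
    for num_edges, sizes in edge_dic.items():
        sizes = np.array(sizes)
        print(f"n={num_nodes}  max_edges={num_edges}: "
              f"mean={sizes.mean():.1f}  min={sizes.min()}  max={sizes.max()}")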

applications/graph/GNN/PyTorch_Implementation/main.py

Lines changed: 2 additions & 2 deletions

The change extends training from 2 epochs to 5 and comments out the per-mini-batch log line, leaving only the per-epoch summaries.

@@ -98,7 +98,7 @@ def main(BATCH_SIZE, dist=False, sync=True):

             print("Writing log information to ", file_name, flush=True)

-        for epoch in range(1, 3):
+        for epoch in range(0, 5):

             epoch_loss = 0
             epoch_start_time = time.perf_counter()
@@ -127,7 +127,7 @@ def main(BATCH_SIZE, dist=False, sync=True):

                 if rank == 0:
                     batch_times.update(time.perf_counter() - _time_start)
-                    print("Mini Batch Times ", i, ": \t", batch_times.mean(), "LOSS: \t", loss_tracker.mean(), flush=True)
+                    #print("Mini Batch Times ", i, ": \t", batch_times.mean(), "LOSS: \t", loss_tracker.mean(), flush=True)
             if dist:
                 torch.distributed.barrier()
Lines changed: 158 additions & 0 deletions

@@ -0,0 +1,158 @@

from utils import get_world_size, init_dist, AverageTracker, get_local_rank
import torch
import argparse
import time
import pickle
from torch_geometric.data import DataLoader

from SYNTH_Trainer import LSC_Trainer

import glob

desc = "PyTorch Geometric Distributed Trainer for OGB LSC dataset"

parser = argparse.ArgumentParser(description=desc)

parser.add_argument(
    '--mini-batch-size', action='store', default=2048, type=int,
    help='mini-batch size (default: 2048)', metavar='NUM')

parser.add_argument(
    '--num-nodes', action='store', default=16, type=int,
    help='number of nodes per graph (default: 16)', metavar='NUM')

parser.add_argument(
    '--num-edges', action='store', default=16, type=int,
    help='maximum number of edges per graph (default: 16)', metavar='NUM')

parser.add_argument('--no-sync', dest='mini_batch_sync', action='store_false')
parser.add_argument('--sync', dest='mini_batch_sync', action='store_true')

parser.add_argument('--dist', dest='dist', action='store_true')

parser.set_defaults(feature=True)
args = parser.parse_args()

mb_size = args.mini_batch_size

num_nodes = args.num_nodes
num_edges = args.num_edges
sync = args.mini_batch_sync

distributed_training = args.dist


def main(BATCH_SIZE, dist=False, sync=True):
    time_stamp = time.strftime("%d-%m-%Y-%H-%M-%S", time.gmtime())

    if dist:
        init_dist("/p/vast1/zaman2/randevous_files_" + str(BATCH_SIZE))
        rank = torch.distributed.get_rank()
    else:
        rank = 0

    primary = rank == 0
    world_size = get_world_size()

    if primary:
        print("Running distributed: ", dist, "\t world size: ", world_size)

    # The file name encodes the node count and the max edge count produced
    # by the dataset generator.
    _files = [f"/p/vast1/zaman2/synth_data/{num_nodes}_{num_edges}_Pytorch.pickle"]
    #_files = glob.glob(_files_str)
    for _file in _files:

        #num_edges = _file.split("/")[-1].split(".")[0].split("_")[1]
        print(num_edges)
        with open(_file, 'rb') as f:
            train_dataset = pickle.load(f)

        train_loader = DataLoader(train_dataset,
                                  batch_size=BATCH_SIZE,
                                  pin_memory=True)

        if dist:
            device = torch.device(f'cuda:{get_local_rank()}' if torch.cuda.is_available() else 'cpu')
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        model = LSC_Trainer(num_nodes).to(device)

        if dist:
            model = torch.nn.parallel.DistributedDataParallel(model,
                                                              device_ids=[get_local_rank()],
                                                              output_device=get_local_rank())

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        criterion = torch.nn.MSELoss()

        if primary:
            file_name = "SYNTHETIC_LOGS/SYNTHETIC_" + str(num_nodes) + "_" + str(num_edges) + ".log"

            logger = open(file_name, 'w')

            print("Writing log information to ", file_name, flush=True)

        for epoch in range(0, 5):

            epoch_loss = 0
            epoch_start_time = time.perf_counter()
            batch_times = AverageTracker()

            loss_tracker = AverageTracker()

            if dist:
                # Assumes the loader was built with a DistributedSampler;
                # the plain DataLoader above has no set_epoch().
                train_loader.sampler.set_epoch(epoch)
            for i, data in enumerate(train_loader):

                _time_start = time.perf_counter()
                data = data.to(device)
                y = data.y

                pred = model(data)
                loss = criterion(y, pred)
                loss_tracker.update(loss.item())
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                #if dist and sync:
                #    torch.distributed.barrier()  # This ensures that Global Mini Batches are synced

                if rank == 0:
                    batch_times.update(time.perf_counter() - _time_start)
                    #print("Mini Batch Times ", i, ": \t", batch_times.mean(), "LOSS: \t", loss_tracker.mean(), flush=True)
            if dist:
                # Wait for every rank before timing the epoch.
                torch.distributed.barrier()

                if primary:
                    message = "Epoch {}: Total elapsed time {:.3f} \t Average Mini Batch Time {:.3f} \n"
                    epoch_time = time.perf_counter() - epoch_start_time

                    logger.write(message.format(epoch, epoch_time, batch_times.mean()))
                    logger.flush()
                    print(message.format(epoch, epoch_time, batch_times.mean()), flush=True)

            else:
                if primary:
                    message = "Epoch {}: Total elapsed time {:.3f} \t Average Mini Batch Time {:.3f} \n"
                    epoch_time = time.perf_counter() - epoch_start_time

                    logger.write(message.format(epoch, epoch_time, batch_times.mean()))
                    logger.flush()
                    print(message.format(epoch, epoch_time, batch_times.mean()), flush=True)

        if primary:
            logger.close()


if __name__ == '__main__':
    # Note: dist and sync are hard-coded off here; the --dist/--sync flags
    # parsed above are not forwarded.
    main(mb_size, False, False)
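
One caveat worth flagging: the dist code path calls train_loader.sampler.set_epoch(epoch), which only exists on torch.utils.data.distributed.DistributedSampler, while the loader above is built without one. A minimal sketch of the wiring that path appears to assume (hypothetical, not in the commit):

from torch.utils.data.distributed import DistributedSampler
from torch_geometric.data import DataLoader

# Shard the dataset across ranks; set_epoch() then reshuffles shards per epoch.
sampler = DistributedSampler(train_dataset)
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          sampler=sampler,
                          pin_memory=True)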
