Skip to content

Commit 924eff4

Browse files
committed
Added lbann data generator
1 parent ef56349 commit 924eff4

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import ogb
import torch
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from ogb.lsc import PCQM4MDataset
from ogb.utils import smiles2graph
from torch_geometric.data import Data
from tqdm import tqdm
12+
13+
14+
# Directory under which the `pcqm4m_kddcup2021` dataset folder lives (or will
# be created on first run). "." keeps the same working-directory-relative
# locations this script has always read the processed files from.
ROOT = "."
DATASET_DIR = os.path.join(ROOT, "pcqm4m_kddcup2021")

# Convert each SMILES string into a molecular graph object by calling smiles2graph.
# This takes a while (a few hours) for the first run.
dataset = PCQM4MDataset(root=ROOT, smiles2graph=smiles2graph)

# Load the processed file and the split file directly from disk.
# `_data` is indexed below by the 'graphs' and 'labels' keys.
_data = torch.load(os.path.join(DATASET_DIR, "processed", "data_processed"))
data_split_indices = torch.load(os.path.join(DATASET_DIR, "split_dict.pt"))

# Dataset indices belonging to each split.
training = data_split_indices['train']
validation = data_split_indices['valid']
23+
24+
25+
# Build one torch_geometric `Data` object per validation molecule and pickle
# the resulting list to VALID_PICKLE_PATH.
VALID_PICKLE_PATH = 'LSC_PCQM4M/valid.bin'

# Create the output directory before the (long) conversion loop so that a
# missing folder cannot make the final write fail after all the work is done.
os.makedirs(os.path.dirname(VALID_PICKLE_PATH), exist_ok=True)

validation_data = []
for index in tqdm(validation):
    graph = _data['graphs'][index]
    homolumogap = _data['labels'][index]

    data = Data()
    # Use the public `num_nodes` setter rather than writing the private
    # `__num_nodes__` attribute directly.
    data.num_nodes = int(graph['num_nodes'])
    data.edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
    data.edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
    data.x = torch.from_numpy(graph['node_feat']).to(torch.int64)
    data.y = torch.Tensor([homolumogap])

    validation_data.append(data)

with open(VALID_PICKLE_PATH, 'wb') as f:
    pickle.dump(validation_data, f)
42+
43+
44+
45+
46+
# Fixed capacities of one flat LBANN sample. Graphs smaller than these are
# padded with -1 (presumably these are the dataset-wide maxima -- TODO confirm;
# a larger graph makes the slice assignments in _pack_graph raise).
MAX_NODES = 51
NODE_FEATURES = 9
MAX_EDGES = 118
EDGE_FEATURES = 3
# node feats + edge feats + sources + targets + node mask + label = 1101 floats
SAMPLE_WIDTH = (MAX_NODES * NODE_FEATURES
                + MAX_EDGES * EDGE_FEATURES
                + 2 * MAX_EDGES
                + MAX_NODES
                + 1)
# Number of samples written between intermediate flushes of the memmap.
FLUSH_EVERY = 10000


def _pack_graph(graph, label):
    """Flatten one graph dict and its label into a float32 vector.

    The vector has SAMPLE_WIDTH entries laid out as:
    node features (feature-major, i.e. the padded (MAX_NODES, NODE_FEATURES)
    matrix transposed and flattened), edge features (same convention),
    edge source indices, edge target indices, node mask (1 for real nodes,
    0 for padding), and finally the label. Unused node/edge slots hold -1.

    Args:
        graph: mapping with 'num_nodes', 'node_feat', 'edge_feat' and
            'edge_index' entries, as stored in ``_data['graphs']``.
        label: scalar target value for this graph.

    Returns:
        A 1-D ``np.float32`` array of length SAMPLE_WIDTH.
    """
    num_nodes = int(graph['num_nodes'])
    num_edges = graph['edge_feat'].shape[0]

    nodes = np.full((MAX_NODES, NODE_FEATURES), -1, dtype=np.float32)
    nodes[:num_nodes, :] = graph['node_feat']

    edges = np.full((MAX_EDGES, EDGE_FEATURES), -1, dtype=np.float32)
    edges[:num_edges, :] = graph['edge_feat']

    sources = np.full(MAX_EDGES, -1, dtype=np.float32)
    targets = np.full(MAX_EDGES, -1, dtype=np.float32)
    sources[:num_edges] = graph['edge_index'][0]
    targets[:num_edges] = graph['edge_index'][1]

    mask = np.zeros(MAX_NODES, dtype=np.float32)
    mask[:num_nodes] = 1

    return np.concatenate([
        nodes.T.flatten(),
        edges.T.flatten(),
        sources,
        targets,
        mask,
        np.array([label], dtype=np.float32),
    ])


# Write one memory-mapped float32 file per split, one row per sample.
for name, _set in [("training", training), ("validation", validation)]:
    filename = 'LBANN_Data_' + name + '.bin'
    fp = np.memmap(filename, dtype='float32', mode='w+',
                   shape=(len(_set), SAMPLE_WIDTH))

    # Iterate over THIS split's indices. `row` is the position inside the
    # split (and therefore inside the file); `index` is the position inside
    # the full dataset and must not be used to address the memmap.
    for row, index in enumerate(tqdm(_set)):
        fp[row, :] = _pack_graph(_data['graphs'][index], _data['labels'][index])

        # Periodically push written rows to disk.
        if (row + 1) % FLUSH_EVERY == 0:
            fp.flush()

    # Make sure the tail of the file reaches disk before moving on.
    fp.flush()

0 commit comments

Comments
 (0)