Skip to content

Commit 1ec4f74

Browse files
authored
Merge pull request #60 from snap-stanford/external_dataset
DatasetSaver class, binary loading and saving, and DGL v0.5
2 parents 5b8bdb1 + af9a973 commit 1ec4f74

28 files changed

Lines changed: 2447 additions & 1098 deletions

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,8 @@ dataset/
1313
*.sh
1414
*analyze*
1515
*random.py
16+
*RELEASE_*
17+
*.csv.gz
18+
*.zip
19+
*submission_
20+
*.npz

README.md

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
<p align="center">
2-
<img width="40%" src="https://snap-stanford.github.io/ogb-web/assets/img/OGB_rectangle.png" />
1+
<p align='center'>
2+
<img width='40%' src='https://snap-stanford.github.io/ogb-web/assets/img/OGB_rectangle.png' />
33
</p>
44

55
--------------------------------------------------------------------------------
@@ -12,8 +12,8 @@
1212
The Open Graph Benchmark (OGB) is a collection of benchmark datasets, data loaders, and evaluators for graph machine learning. Datasets cover a variety of graph machine learning tasks and real-world applications.
1313
The OGB data loaders are fully compatible with popular graph deep learning frameworks, including [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) and [Deep Graph Library (DGL)](https://www.dgl.ai/). They provide automatic dataset downloading, standardized dataset splits, and unified performance evaluation.
1414

15-
<p align="center">
16-
<img width="80%" src="https://snap-stanford.github.io/ogb-web/assets/img/ogb_overview.png" />
15+
<p align='center'>
16+
<img width='80%' src='https://snap-stanford.github.io/ogb-web/assets/img/ogb_overview.png' />
1717
</p>
1818

1919
OGB aims to provide graph datasets that cover important graph machine learning tasks, diverse dataset scale, and rich domains.
@@ -24,8 +24,8 @@ OGB aims to provide graph datasets that cover important graph machine learning t
2424

2525
**Rich domains:** Graph datasets come from diverse domains ranging from scientific ones to social/information networks, and also include heterogeneous knowledge graphs.
2626

27-
<p align="center">
28-
<img width="70%" src="https://snap-stanford.github.io/ogb-web/assets/img/dataset_overview.png" />
27+
<p align='center'>
28+
<img width='70%' src='https://snap-stanford.github.io/ogb-web/assets/img/dataset_overview.png' />
2929
</p>
3030

3131
OGB is an on-going effort, and we are planning to increase our coverage in the future.
@@ -38,7 +38,7 @@ The release note is available [here](https://github.com/snap-stanford/ogb/releas
3838
#### Requirements
3939
- Python>=3.5
4040
- PyTorch>=1.2
41-
- DGL>=0.4.1 or torch-geometric>=1.3.1
41+
- DGL>=0.5.0 or torch-geometric>=1.6.0
4242
- Numpy>=1.16.0
4343
- pandas>=0.24.0
4444
- urllib3>=1.24.0
@@ -77,12 +77,12 @@ Below, on PyTorch Geometric, we see that a few lines of code is sufficient to pr
7777
from ogb.graphproppred import PygGraphPropPredDataset
7878
from torch_geometric.data import DataLoader
7979

80-
dataset = PygGraphPropPredDataset(name = "ogbg-molhiv")
80+
dataset = PygGraphPropPredDataset(name = 'ogbg-molhiv')
8181

8282
split_idx = dataset.get_idx_split()
83-
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=32, shuffle=True)
84-
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=32, shuffle=False)
85-
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=32, shuffle=False)
83+
train_loader = DataLoader(dataset[split_idx['train']], batch_size=32, shuffle=True)
84+
valid_loader = DataLoader(dataset[split_idx['valid']], batch_size=32, shuffle=False)
85+
test_loader = DataLoader(dataset[split_idx['test']], batch_size=32, shuffle=False)
8686
```
8787

8888
#### (2) Evaluators
@@ -91,12 +91,12 @@ The standardized evaluation protocol allows researchers to reliably compare thei
9191
```python
9292
from ogb.graphproppred import Evaluator
9393

94-
evaluator = Evaluator(name = "ogbg-molhiv")
94+
evaluator = Evaluator(name = 'ogbg-molhiv')
9595
# You can learn the input and output format specification of the evaluator as follows.
9696
# print(evaluator.expected_input_format)
9797
# print(evaluator.expected_output_format)
98-
input_dict = {"y_true": y_true, "y_pred": y_pred}
99-
result_dict = evaluator.eval(input_dict) # E.g., {"rocauc": 0.7321}
98+
input_dict = {'y_true': y_true, 'y_pred': y_pred}
99+
result_dict = evaluator.eval(input_dict) # E.g., {'rocauc': 0.7321}
100100
```
101101

102102
## Citing OGB

ogb/graphproppred/dataset.py

Lines changed: 94 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,56 @@
33
import numpy as np
44
import os.path as osp
55
from ogb.utils.url import decide_download, download_url, extract_zip
6-
from ogb.io.read_graph_raw import read_csv_graph_raw
6+
from ogb.io.read_graph_raw import read_csv_graph_raw, read_binary_graph_raw
77
import torch
88

99
class GraphPropPredDataset(object):
10-
def __init__(self, name, root = "dataset"):
11-
self.name = name ## original name, e.g., ogbg-mol-tox21
12-
self.dir_name = "_".join(name.split("-")) ## replace hyphen with underline, e.g., ogbg_mol_tox21
13-
14-
self.original_root = root
15-
self.root = osp.join(root, self.dir_name)
16-
17-
self.meta_info = pd.read_csv(os.path.join(os.path.dirname(__file__), "master.csv"), index_col = 0)
18-
if not self.name in self.meta_info:
19-
print(self.name)
20-
error_mssg = "Invalid dataset name {}.\n".format(self.name)
21-
error_mssg += "Available datasets are as follows:\n"
22-
error_mssg += "\n".join(self.meta_info.keys())
23-
raise ValueError(error_mssg)
10+
def __init__(self, name, root = 'dataset', meta_dict = None):
11+
'''
12+
- name (str): name of the dataset
13+
- root (str): root directory to store the dataset folder
14+
15+
- meta_dict: dictionary that stores all the meta-information about data. Default is None,
16+
but when a dictionary is passed, its contents are used instead. Useful for debugging by external contributors.
17+
'''
18+
19+
self.name = name ## original name, e.g., ogbg-molhiv
20+
21+
if meta_dict is None:
22+
self.dir_name = '_'.join(name.split('-')) ## replace hyphens with underscores, e.g., ogbg_molhiv
23+
self.original_root = root
24+
self.root = osp.join(root, self.dir_name)
25+
26+
master = pd.read_csv(os.path.join(os.path.dirname(__file__), 'master.csv'), index_col = 0)
27+
if not self.name in master:
28+
error_mssg = 'Invalid dataset name {}.\n'.format(self.name)
29+
error_mssg += 'Available datasets are as follows:\n'
30+
error_mssg += '\n'.join(master.keys())
31+
raise ValueError(error_mssg)
32+
self.meta_info = master[self.name]
33+
34+
else:
35+
self.dir_name = meta_dict['dir_path']
36+
self.original_root = ''
37+
self.root = meta_dict['dir_path']
38+
self.meta_info = meta_dict
2439

2540
# check version
2641
# First check whether the dataset has been already downloaded or not.
2742
# If so, check whether the dataset version is the newest or not.
2843
# If the dataset is not the newest version, notify this to the user.
29-
if osp.isdir(self.root) and (not osp.exists(osp.join(self.root, 'RELEASE_v' + str(self.meta_info[self.name]['version']) + '.txt'))):
44+
if osp.isdir(self.root) and (not osp.exists(osp.join(self.root, 'RELEASE_v' + str(self.meta_info['version']) + '.txt'))):
3045
print(self.name + ' has been updated.')
31-
if input("Will you update the dataset now? (y/N)\n").lower() == "y":
46+
if input('Will you update the dataset now? (y/N)\n').lower() == 'y':
3247
shutil.rmtree(self.root)
3348

34-
self.download_name = self.meta_info[self.name]["download_name"] ## name of downloaded file, e.g., tox21
49+
self.download_name = self.meta_info['download_name'] ## name of downloaded file, e.g., tox21
3550

36-
self.num_tasks = int(self.meta_info[self.name]["num tasks"])
37-
self.eval_metric = self.meta_info[self.name]["eval metric"]
38-
self.task_type = self.meta_info[self.name]["task type"]
39-
self.num_classes = self.meta_info[self.name]["num classes"]
51+
self.num_tasks = int(self.meta_info['num tasks'])
52+
self.eval_metric = self.meta_info['eval metric']
53+
self.task_type = self.meta_info['task type']
54+
self.num_classes = self.meta_info['num classes']
55+
self.binary = self.meta_info['binary'] == 'True'
4056

4157
super(GraphPropPredDataset, self).__init__()
4258

@@ -52,63 +68,81 @@ def pre_process(self):
5268
self.graphs, self.labels = loaded_dict['graphs'], loaded_dict['labels']
5369

5470
else:
55-
### download
56-
url = self.meta_info[self.name]["url"]
57-
if decide_download(url):
58-
path = download_url(url, self.original_root)
59-
extract_zip(path, self.original_root)
60-
os.unlink(path)
61-
# delete folder if there exists
62-
try:
63-
shutil.rmtree(self.root)
64-
except:
65-
pass
66-
shutil.move(osp.join(self.original_root, self.download_name), self.root)
71+
### check download
72+
if self.binary:
73+
# npz format
74+
has_necessary_file = osp.exists(osp.join(self.root, 'raw', 'data.npz'))
6775
else:
68-
print("Stop download.")
69-
exit(-1)
76+
# csv file
77+
has_necessary_file = osp.exists(osp.join(self.root, 'raw', 'edge.csv.gz'))
78+
79+
### download
80+
if not has_necessary_file:
81+
url = self.meta_info['url']
82+
if decide_download(url):
83+
path = download_url(url, self.original_root)
84+
extract_zip(path, self.original_root)
85+
os.unlink(path)
86+
# delete the folder if it already exists
87+
try:
88+
shutil.rmtree(self.root)
89+
except:
90+
pass
91+
shutil.move(osp.join(self.original_root, self.download_name), self.root)
92+
else:
93+
print('Stop download.')
94+
exit(-1)
7095

7196
### preprocess
72-
add_inverse_edge = self.meta_info[self.name]["add_inverse_edge"] == "True"
97+
add_inverse_edge = self.meta_info['add_inverse_edge'] == 'True'
7398

74-
if self.meta_info[self.name]["additional node files"] == 'None':
99+
if self.meta_info['additional node files'] == 'None':
75100
additional_node_files = []
76101
else:
77-
additional_node_files = self.meta_info[self.name]["additional node files"].split(',')
102+
additional_node_files = self.meta_info['additional node files'].split(',')
78103

79-
if self.meta_info[self.name]["additional edge files"] == 'None':
104+
if self.meta_info['additional edge files'] == 'None':
80105
additional_edge_files = []
81106
else:
82-
additional_edge_files = self.meta_info[self.name]["additional edge files"].split(',')
83-
84-
self.graphs = read_csv_graph_raw(raw_dir, add_inverse_edge = add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files)
85-
107+
additional_edge_files = self.meta_info['additional edge files'].split(',')
108+
109+
if self.binary:
110+
self.graphs = read_binary_graph_raw(raw_dir, add_inverse_edge = add_inverse_edge)
111+
else:
112+
self.graphs = read_csv_graph_raw(raw_dir, add_inverse_edge = add_inverse_edge, additional_node_files = additional_node_files, additional_edge_files = additional_edge_files)
86113

87114
if self.task_type == 'subtoken prediction':
88-
labels_joined = pd.read_csv(osp.join(raw_dir, "graph-label.csv.gz"), compression="gzip", header = None).values
115+
labels_joined = pd.read_csv(osp.join(raw_dir, 'graph-label.csv.gz'), compression='gzip', header = None).values
89116
# need to split each element into subtokens
90117
self.labels = [str(labels_joined[i][0]).split(' ') for i in range(len(labels_joined))]
91118
else:
92-
self.labels = pd.read_csv(osp.join(raw_dir, "graph-label.csv.gz"), compression="gzip", header = None).values
119+
if self.binary:
120+
self.labels = np.load(osp.join(raw_dir, 'graph-label.npz'))['graph_label']
121+
else:
122+
self.labels = pd.read_csv(osp.join(raw_dir, 'graph-label.csv.gz'), compression='gzip', header = None).values
93123

94124
print('Saving...')
95125
torch.save({'graphs': self.graphs, 'labels': self.labels}, pre_processed_file_path, pickle_protocol=4)
96126

97127

98128
def get_idx_split(self, split_type = None):
99129
if split_type is None:
100-
split_type = self.meta_info[self.name]["split"]
130+
split_type = self.meta_info['split']
101131

102-
path = osp.join(self.root, "split", split_type)
132+
path = osp.join(self.root, 'split', split_type)
133+
134+
# short-cut if split_dict.pt exists
135+
if os.path.isfile(os.path.join(path, 'split_dict.pt')):
136+
return torch.load(os.path.join(path, 'split_dict.pt'))
103137

104-
train_idx = pd.read_csv(osp.join(path, "train.csv.gz"), compression="gzip", header = None).values.T[0]
105-
valid_idx = pd.read_csv(osp.join(path, "valid.csv.gz"), compression="gzip", header = None).values.T[0]
106-
test_idx = pd.read_csv(osp.join(path, "test.csv.gz"), compression="gzip", header = None).values.T[0]
138+
train_idx = pd.read_csv(osp.join(path, 'train.csv.gz'), compression='gzip', header = None).values.T[0]
139+
valid_idx = pd.read_csv(osp.join(path, 'valid.csv.gz'), compression='gzip', header = None).values.T[0]
140+
test_idx = pd.read_csv(osp.join(path, 'test.csv.gz'), compression='gzip', header = None).values.T[0]
107141

108-
return {"train": train_idx, "valid": valid_idx, "test": test_idx}
142+
return {'train': train_idx, 'valid': valid_idx, 'test': test_idx}
109143

110144
def __getitem__(self, idx):
111-
"""Get datapoint with index"""
145+
'''Get datapoint with index'''
112146

113147
if isinstance(idx, (int, np.integer)):
114148
return self.graphs[idx], self.labels[idx]
@@ -117,20 +151,20 @@ def __getitem__(self, idx):
117151
'Only integer is valid index (got {}).'.format(type(idx).__name__))
118152

119153
def __len__(self):
120-
"""Length of the dataset
154+
'''Length of the dataset
121155
Returns
122156
-------
123157
int
124158
Length of Dataset
125-
"""
159+
'''
126160
return len(self.graphs)
127161

128162
def __repr__(self): # pragma: no cover
129163
return '{}({})'.format(self.__class__.__name__, len(self))
130164

131165

132-
if __name__ == "__main__":
133-
dataset = GraphPropPredDataset(name = "ogbg-code")
166+
if __name__ == '__main__':
167+
dataset = GraphPropPredDataset(name = 'ogbg-code')
134168
# target_list = np.array([len(label) for label in dataset.labels])
135169
# print(np.sum(target_list == 1)/ float(len(target_list)))
136170
# print(np.sum(target_list == 2)/ float(len(target_list)))
@@ -144,8 +178,8 @@ def __repr__(self): # pragma: no cover
144178
print(split_index)
145179
# print(dataset)
146180
# print(dataset[2])
147-
# print(split_index["train"])
148-
# print(split_index["valid"])
149-
# print(split_index["test"])
181+
# print(split_index['train'])
182+
# print(split_index['valid'])
183+
# print(split_index['test'])
150184

151185

0 commit comments

Comments
 (0)