33import numpy as np
44import os .path as osp
55from ogb .utils .url import decide_download , download_url , extract_zip
6- from ogb .io .read_graph_raw import read_csv_graph_raw
6+ from ogb .io .read_graph_raw import read_csv_graph_raw , read_binary_graph_raw
77import torch
88
99class GraphPropPredDataset (object ):
10- def __init__ (self , name , root = "dataset" ):
11- self .name = name ## original name, e.g., ogbg-mol-tox21
12- self .dir_name = "_" .join (name .split ("-" )) ## replace hyphen with underline, e.g., ogbg_mol_tox21
13-
14- self .original_root = root
15- self .root = osp .join (root , self .dir_name )
16-
17- self .meta_info = pd .read_csv (os .path .join (os .path .dirname (__file__ ), "master.csv" ), index_col = 0 )
18- if not self .name in self .meta_info :
19- print (self .name )
20- error_mssg = "Invalid dataset name {}.\n " .format (self .name )
21- error_mssg += "Available datasets are as follows:\n "
22- error_mssg += "\n " .join (self .meta_info .keys ())
23- raise ValueError (error_mssg )
10+ def __init__ (self , name , root = 'dataset' , meta_dict = None ):
11+ '''
12+ - name (str): name of the dataset
13+ - root (str): root directory to store the dataset folder
14+
15+ - meta_dict: dictionary that stores all the meta-information about data. Default is None,
16+ but when something is passed, it uses its information. Useful for debugging for external contributers.
17+ '''
18+
19+ self .name = name ## original name, e.g., ogbg-hib
20+
21+ if meta_dict is None :
22+ self .dir_name = '_' .join (name .split ('-' )) ## replace hyphen with underline, e.g., ogbg_hiv
23+ self .original_root = root
24+ self .root = osp .join (root , self .dir_name )
25+
26+ master = pd .read_csv (os .path .join (os .path .dirname (__file__ ), 'master.csv' ), index_col = 0 )
27+ if not self .name in master :
28+ error_mssg = 'Invalid dataset name {}.\n ' .format (self .name )
29+ error_mssg += 'Available datasets are as follows:\n '
30+ error_mssg += '\n ' .join (master .keys ())
31+ raise ValueError (error_mssg )
32+ self .meta_info = master [self .name ]
33+
34+ else :
35+ self .dir_name = meta_dict ['dir_path' ]
36+ self .original_root = ''
37+ self .root = meta_dict ['dir_path' ]
38+ self .meta_info = meta_dict
2439
2540 # check version
2641 # First check whether the dataset has been already downloaded or not.
2742 # If so, check whether the dataset version is the newest or not.
2843 # If the dataset is not the newest version, notify this to the user.
29- if osp .isdir (self .root ) and (not osp .exists (osp .join (self .root , 'RELEASE_v' + str (self .meta_info [self . name ][ 'version' ]) + '.txt' ))):
44+ if osp .isdir (self .root ) and (not osp .exists (osp .join (self .root , 'RELEASE_v' + str (self .meta_info ['version' ]) + '.txt' ))):
3045 print (self .name + ' has been updated.' )
31- if input (" Will you update the dataset now? (y/N)\n " ).lower () == "y" :
46+ if input (' Will you update the dataset now? (y/N)\n ' ).lower () == 'y' :
3247 shutil .rmtree (self .root )
3348
34- self .download_name = self .meta_info [self . name ][ " download_name" ] ## name of downloaded file, e.g., tox21
49+ self .download_name = self .meta_info [' download_name' ] ## name of downloaded file, e.g., tox21
3550
36- self .num_tasks = int (self .meta_info [self .name ]["num tasks" ])
37- self .eval_metric = self .meta_info [self .name ]["eval metric" ]
38- self .task_type = self .meta_info [self .name ]["task type" ]
39- self .num_classes = self .meta_info [self .name ]["num classes" ]
51+ self .num_tasks = int (self .meta_info ['num tasks' ])
52+ self .eval_metric = self .meta_info ['eval metric' ]
53+ self .task_type = self .meta_info ['task type' ]
54+ self .num_classes = self .meta_info ['num classes' ]
55+ self .binary = self .meta_info ['binary' ] == 'True'
4056
4157 super (GraphPropPredDataset , self ).__init__ ()
4258
@@ -52,63 +68,81 @@ def pre_process(self):
5268 self .graphs , self .labels = loaded_dict ['graphs' ], loaded_dict ['labels' ]
5369
5470 else :
55- ### download
56- url = self .meta_info [self .name ]["url" ]
57- if decide_download (url ):
58- path = download_url (url , self .original_root )
59- extract_zip (path , self .original_root )
60- os .unlink (path )
61- # delete folder if there exists
62- try :
63- shutil .rmtree (self .root )
64- except :
65- pass
66- shutil .move (osp .join (self .original_root , self .download_name ), self .root )
71+ ### check download
72+ if self .binary :
73+ # npz format
74+ has_necessary_file = osp .exists (osp .join (self .root , 'raw' , 'data.npz' ))
6775 else :
68- print ("Stop download." )
69- exit (- 1 )
76+ # csv file
77+ has_necessary_file = osp .exists (osp .join (self .root , 'raw' , 'edge.csv.gz' ))
78+
79+ ### download
80+ if not has_necessary_file :
81+ url = self .meta_info ['url' ]
82+ if decide_download (url ):
83+ path = download_url (url , self .original_root )
84+ extract_zip (path , self .original_root )
85+ os .unlink (path )
86+ # delete folder if there exists
87+ try :
88+ shutil .rmtree (self .root )
89+ except :
90+ pass
91+ shutil .move (osp .join (self .original_root , self .download_name ), self .root )
92+ else :
93+ print ('Stop download.' )
94+ exit (- 1 )
7095
7196 ### preprocess
72- add_inverse_edge = self .meta_info [self . name ][ " add_inverse_edge" ] == " True"
97+ add_inverse_edge = self .meta_info [' add_inverse_edge' ] == ' True'
7398
74- if self .meta_info [self . name ][ " additional node files" ] == 'None' :
99+ if self .meta_info [' additional node files' ] == 'None' :
75100 additional_node_files = []
76101 else :
77- additional_node_files = self .meta_info [self . name ][ " additional node files" ].split (',' )
102+ additional_node_files = self .meta_info [' additional node files' ].split (',' )
78103
79- if self .meta_info [self . name ][ " additional edge files" ] == 'None' :
104+ if self .meta_info [' additional edge files' ] == 'None' :
80105 additional_edge_files = []
81106 else :
82- additional_edge_files = self .meta_info [self .name ]["additional edge files" ].split (',' )
83-
84- self .graphs = read_csv_graph_raw (raw_dir , add_inverse_edge = add_inverse_edge , additional_node_files = additional_node_files , additional_edge_files = additional_edge_files )
85-
107+ additional_edge_files = self .meta_info ['additional edge files' ].split (',' )
108+
109+ if self .binary :
110+ self .graphs = read_binary_graph_raw (raw_dir , add_inverse_edge = add_inverse_edge )
111+ else :
112+ self .graphs = read_csv_graph_raw (raw_dir , add_inverse_edge = add_inverse_edge , additional_node_files = additional_node_files , additional_edge_files = additional_edge_files )
86113
87114 if self .task_type == 'subtoken prediction' :
88- labels_joined = pd .read_csv (osp .join (raw_dir , " graph-label.csv.gz" ), compression = " gzip" , header = None ).values
115+ labels_joined = pd .read_csv (osp .join (raw_dir , ' graph-label.csv.gz' ), compression = ' gzip' , header = None ).values
89116 # need to split each element into subtokens
90117 self .labels = [str (labels_joined [i ][0 ]).split (' ' ) for i in range (len (labels_joined ))]
91118 else :
92- self .labels = pd .read_csv (osp .join (raw_dir , "graph-label.csv.gz" ), compression = "gzip" , header = None ).values
119+ if self .binary :
120+ self .labels = np .load (osp .join (raw_dir , 'graph-label.npz' ))['graph_label' ]
121+ else :
122+ self .labels = pd .read_csv (osp .join (raw_dir , 'graph-label.csv.gz' ), compression = 'gzip' , header = None ).values
93123
94124 print ('Saving...' )
95125 torch .save ({'graphs' : self .graphs , 'labels' : self .labels }, pre_processed_file_path , pickle_protocol = 4 )
96126
97127
98128 def get_idx_split (self , split_type = None ):
99129 if split_type is None :
100- split_type = self .meta_info [self . name ][ " split" ]
130+ split_type = self .meta_info [' split' ]
101131
102- path = osp .join (self .root , "split" , split_type )
132+ path = osp .join (self .root , 'split' , split_type )
133+
134+ # short-cut if split_dict.pt exists
135+ if os .path .isfile (os .path .join (path , 'split_dict.pt' )):
136+ return torch .load (os .path .join (path , 'split_dict.pt' ))
103137
104- train_idx = pd .read_csv (osp .join (path , " train.csv.gz" ), compression = " gzip" , header = None ).values .T [0 ]
105- valid_idx = pd .read_csv (osp .join (path , " valid.csv.gz" ), compression = " gzip" , header = None ).values .T [0 ]
106- test_idx = pd .read_csv (osp .join (path , " test.csv.gz" ), compression = " gzip" , header = None ).values .T [0 ]
138+ train_idx = pd .read_csv (osp .join (path , ' train.csv.gz' ), compression = ' gzip' , header = None ).values .T [0 ]
139+ valid_idx = pd .read_csv (osp .join (path , ' valid.csv.gz' ), compression = ' gzip' , header = None ).values .T [0 ]
140+ test_idx = pd .read_csv (osp .join (path , ' test.csv.gz' ), compression = ' gzip' , header = None ).values .T [0 ]
107141
108- return {" train" : train_idx , " valid" : valid_idx , " test" : test_idx }
142+ return {' train' : train_idx , ' valid' : valid_idx , ' test' : test_idx }
109143
110144 def __getitem__ (self , idx ):
111- """ Get datapoint with index"""
145+ ''' Get datapoint with index'''
112146
113147 if isinstance (idx , (int , np .integer )):
114148 return self .graphs [idx ], self .labels [idx ]
@@ -117,20 +151,20 @@ def __getitem__(self, idx):
117151 'Only integer is valid index (got {}).' .format (type (idx ).__name__ ))
118152
119153 def __len__ (self ):
120- """ Length of the dataset
154+ ''' Length of the dataset
121155 Returns
122156 -------
123157 int
124158 Length of Dataset
125- """
159+ '''
126160 return len (self .graphs )
127161
128162 def __repr__ (self ): # pragma: no cover
129163 return '{}({})' .format (self .__class__ .__name__ , len (self ))
130164
131165
132- if __name__ == " __main__" :
133- dataset = GraphPropPredDataset (name = " ogbg-code" )
166+ if __name__ == ' __main__' :
167+ dataset = GraphPropPredDataset (name = ' ogbg-code' )
134168 # target_list = np.array([len(label) for label in dataset.labels])
135169 # print(np.sum(target_list == 1)/ float(len(target_list)))
136170 # print(np.sum(target_list == 2)/ float(len(target_list)))
@@ -144,8 +178,8 @@ def __repr__(self): # pragma: no cover
144178 print (split_index )
145179 # print(dataset)
146180 # print(dataset[2])
147- # print(split_index[" train" ])
148- # print(split_index[" valid" ])
149- # print(split_index[" test" ])
181+ # print(split_index[' train' ])
182+ # print(split_index[' valid' ])
183+ # print(split_index[' test' ])
150184
151185
0 commit comments