-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
48 lines (36 loc) · 1.4 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import json
import torch
from rdkit import RDLogger
from utils.datasets import MOLECULAR_DATASETS, BASE_DIR, load_dataset
from utils.train import train, evaluate
from utils.evaluate import count_parameters
from models import pgc_marg
# Registry from the model name given in a config's ``hyperpars['model']`` to
# the callable that builds that model. A fresh dict is made so that other
# model families can be merged in here without mutating ``pgc_marg.MODELS``.
MODELS = dict(pgc_marg.MODELS)

# Directory handed to train()/evaluate() by the entry point. Built by plain
# string concatenation, so this assumes BASE_DIR already ends with '/'
# (NOTE(review): confirm against utils.datasets).
BASE_DIR_TRN = '{}trn/'.format(BASE_DIR)
def _run_backend(dataset, backend):
    """Train and evaluate the model configured by one backend config file.

    Loads hyperparameters from ``config/<dataset>/<backend>.json`` and injects
    the dataset's ``atom_list``. It then builds the loaders (0.8/0.1/0.1 split,
    using the config's ``order``) and the model named by
    ``hyperpars['model']``, prints a summary, trains, and prints the metrics
    returned by ``evaluate``.

    Args:
        dataset: Key into ``MOLECULAR_DATASETS`` and the config sub-directory.
        backend: Config file stem under ``config/<dataset>/``.
    """
    with open(f'config/{dataset}/{backend}.json', 'r', encoding='utf-8') as f:
        hyperpars = json.load(f)
    hyperpars['atom_list'] = MOLECULAR_DATASETS[dataset]['atom_list']

    loaders = load_dataset(dataset, hyperpars['batch_size'], [0.8, 0.1, 0.1], order=hyperpars['order'])
    model = MODELS[hyperpars['model']](loaders['loader_trn'], hyperpars['model_hpars'])

    print(dataset)
    print(json.dumps(hyperpars, indent=4))
    print(model)
    print(f'The number of parameters is {count_parameters(model)}.')
    print(hyperpars['order'])

    train(model, loaders, hyperpars, BASE_DIR_TRN)
    # NOTE(review): `evaluate` is not passed `model`; presumably it reloads
    # the trained model from BASE_DIR_TRN -- confirm in utils.train.
    metrics = evaluate(loaders, hyperpars, BASE_DIR_TRN, compute_nll=True)
    print("\n".join(f'{key:<16}{value:>10.4f}' for key, value in metrics.items()))


def main():
    """Set global torch/RDKit options, then run every selected backend."""
    torch.set_float32_matmul_precision('medium')
    # Silence RDKit's per-molecule log output.
    RDLogger.DisableLog('rdApp.*')

    dataset = 'qm9'
    # Trailing commas matter: without them, uncommenting an entry would
    # silently concatenate adjacent string literals into one bogus name.
    backends = [
        # 'marg_sort_btree',
        # 'marg_sort_vtree',
        # 'marg_sort_rtree',
        # 'marg_sort_ptree',
        'marg_sort_ctree',
    ]
    for backend in backends:
        _run_backend(dataset, backend)


if __name__ == '__main__':
    main()