-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrain.py
68 lines (60 loc) · 2.1 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Standard library
import argparse
import copy
import os
import pickle
from glob import glob

# Third-party
import numpy as np
import torch
import yaml
from easydict import EasyDict
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

# Project-local
# NOTE(review): `Batch` (used below) is presumably re-exported by one of the
# star imports (utils.transforms / utils.misc) — confirm; otherwise add
# `from torch_geometric.data import Batch` explicitly.
from models.epsnet import *
from utils.common import get_optimizer, get_scheduler
from utils.dataloader import DataLoader
from utils.datasets import ConformationDataset
from utils.misc import *
from utils.transforms import *
# ---------------------------------------------------------------------------
# Fine-tuning script: trains an epsnet model on paired (query, reference)
# conformations stored in a single pickle file.  Each sample carries the
# reference copy of its attributes under "*_r" keys; every batch is split
# into a query batch and a reference batch before computing the loss.
# ---------------------------------------------------------------------------
device = 'cuda:0'
torch.set_printoptions(precision=2, sci_mode=False)

# Load the model configuration.  The original script referenced an undefined
# name `config`, which crashed with a NameError before training started.
config_path = 'configs/default.yml'  # TODO(review): confirm config location
with open(config_path, 'r') as f:
    config = EasyDict(yaml.safe_load(f))

# Init model.
model = get_model(config.model).to(device)

# Dataset of paired conformations; "*_r" keys hold the reference attributes.
# (The unused `CountNodesPerGraph()` transform from the original was dropped:
# the dataset was constructed with transform=None anyway.)
data_path = 'cross_set_v2/zinc_all.pkl'
val_set = ConformationDataset(data_path, transform=None)
keys = ["atom_type_r", "edge_index_r", "edge_type_r", "pos_r", "smiles_r", "rdmol_r"]
val_tmp = DataLoader(val_set, batch_size=16, shuffle=True, num_workers=4)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=10)

iter_num = 10000   # number of full passes over the loader
save_every = 1     # checkpoint interval in iterations (original saved every pass)
ckpt_dir = './param'
os.makedirs(ckpt_dir, exist_ok=True)  # original crashed if ./param was missing

model.train()
for it in range(iter_num):
    losses = []
    for batch in val_tmp:
        optimizer.zero_grad()
        # Split the merged batch back into per-graph objects, then build two
        # parallel batches: `batch` keeps the query attributes, `batch2` gets
        # the reference ("*_r") values written under the base names.  The
        # "*_r" keys themselves are removed from both copies.
        batch = batch.to_data_list()
        batch2 = copy.deepcopy(batch)
        for i in range(len(batch)):
            for key in keys:
                base = key[:-2]  # strip the "_r" suffix
                del batch2[i][base]
                batch2[i][base] = batch[i][key]
                del batch[i][key]
                del batch2[i][key]
        batch = Batch.from_data_list(batch)
        batch2 = Batch.from_data_list(batch2)

        loss = model.get_loss(
            query_batch=copy.deepcopy(batch).to(device),
            reference_batch=copy.deepcopy(batch2).to(device),
        )
        if torch.isnan(loss):
            # Skip the optimizer step AND do not record the value: a single
            # NaN appended to `losses` would poison the epoch mean that the
            # LR scheduler consumes.
            print("nan detected...")
        else:
            loss.backward()
            clip_grad_norm_(model.parameters(), max_norm=0.1)
            optimizer.step()
            losses.append(loss.item())

    # Guard against a pass in which every batch produced NaN.
    loss_mean = float(np.mean(losses)) if losses else float('nan')
    print(f'[iter {it+1}], loss: {loss_mean:.4f}')
    scheduler.step(loss_mean)

    if (it + 1) % save_every == 0:
        torch.save(model.state_dict(), os.path.join(ckpt_dir, f'{it+1}.pt'))