-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
86 lines (67 loc) · 2.91 KB
/
train.py
File metadata and controls
86 lines (67 loc) · 2.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import math
from model.encoder import Encoder
from util.dataset import PlanDataset
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchsummary import summary
# Module-level setup, executed at import time: load the plan dataset, split it,
# and build the model / loss / optimizer that train() reads as globals.
# Device selection is commented out here; no explicit .to(device) calls are made.

dataset = PlanDataset(root_dir="data/deep_cardinality")
# NOTE(review): `dataloader` is not used by train() or dataset_test(), which
# iterate the split datasets directly, one sample at a time.
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

# 80/20 split; the test part takes the remainder so both lengths sum to
# len(dataset), which random_split requires.
train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset,
    lengths=[train_size, test_size],
)

# d_feature = 9 + 6 + 64 -- presumably the widths of the per-node feature groups
# produced by PlanDataset; TODO confirm against util.dataset.
encoder = Encoder(
    d_feature=9 + 6 + 64,
    d_model=256,
    d_ff=128,
    N=4,
).double()  # float64 parameters, so model inputs must be double as well
summary(encoder)

criterion = nn.MSELoss()
optimizer = optim.Adam(params=encoder.parameters(), lr=0.001)

epoch_size = 2  # number of passes over train_dataset
def train(log_every=200):
    """Train the module-level ``encoder`` and evaluate it after every epoch.

    Runs ``epoch_size`` epochs. Each epoch iterates ``train_dataset`` one
    sample at a time (the split datasets are iterated directly, not through a
    DataLoader), then evaluates on ``test_dataset`` under ``torch.no_grad()``.

    Args:
        log_every: number of training samples per averaged running-loss report.

    Returns:
        A list of ``(label, test_output)`` pairs, one per test sample,
        collected during the final epoch only.

    Raises:
        ValueError: if ``log_every`` is not a positive integer.
    """
    if log_every <= 0:
        raise ValueError("log_every must be a positive integer")

    result = []
    for epoch in range(epoch_size):
        print("epoch : ", epoch)
        # Switch back to training mode each epoch (eval() is set below for
        # testing). NOTE(review): this only matters if Encoder contains
        # dropout / batch-norm layers; it is harmless otherwise.
        encoder.train()
        running_loss = 0.0
        for i, data in enumerate(train_dataset):
            tree, nodemat, leafmat, label = data
            optimizer.zero_grad()
            # The encoder was converted with .double(), so inputs must match.
            output = encoder(tree, nodemat.double(), leafmat.double())
            if len(output.shape) > 1 or len(label.shape) > 1:
                print("output: {} ,label: {}".format(len(output.shape), len(label.shape)))
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if math.isnan(running_loss):
                print("nan: ", i, "\t", running_loss)
            # Report after each full window of `log_every` samples, so the
            # average really covers `log_every` losses (the previous
            # `i % 200 == 0 and i != 0` test first fired after 201 samples).
            if (i + 1) % log_every == 0:
                print("[%d, %5d] loss: %4f" % (epoch + 1, i + 1, running_loss / log_every))
                running_loss = 0.0

        encoder.eval()
        test_loss = 0.0
        n_test = 0
        with torch.no_grad():
            for data in test_dataset:
                tree, nodemat, leafmat, label = data
                # Cast exactly as in training; uncast inputs would not match
                # the double-precision encoder.
                test_output = encoder(tree, nodemat.double(), leafmat.double())
                if epoch == epoch_size - 1:
                    result.append((label, test_output))
                test_loss += criterion(test_output, label).item()
                n_test += 1
        # Average over the samples actually evaluated, reported once per epoch
        # (previously a partial sum was divided by test_size mid-loop, and
        # nothing was printed for test sets of 200 samples or fewer).
        if n_test:
            print("test loss: ", test_loss / n_test)
    return result
def dataset_test():
    """Print the label of every sample in the module-level ``test_dataset``.

    Each sample is unpacked as ``(tree, nodemat, leafmat, label)``; only the
    label is used.
    """
    for _tree, _nodemat, _leafmat, label in test_dataset:
        print(label)
if __name__ == "__main__":
    from pathlib import Path

    # File name kept byte-for-byte ("resutldeep" sic), in case other scripts
    # read this exact path.
    out_path = Path("data/dmodel256/resutldeep_cv1.0dff128-e2-N4-lr0.001.txt")
    # Create the output directory before training, so a missing folder fails
    # fast instead of discarding the results of a finished run.
    out_path.parent.mkdir(parents=True, exist_ok=True)

    result = train()
    # One "label prediction" line per test sample from the final epoch.
    with out_path.open("w", encoding="utf-8") as f:
        f.write("\n".join("{} {}".format(label.item(), prediction.item()) for label, prediction in result))
    # torch.save(encoder, "model_parameter/encoderv1.0.pkl")