"""tpp_train.py: training entry point for the attentive diffusion temporal point process."""
import argparse
import json
import os
from pathlib import Path

import numpy
import torch

from models.prob_decoders import *
from models.embeddings import *
from models.hist_encoders import *
from datasets.tpp_loader import *
from models.tpp_warper import TPPWarper
from trainers.trainer import Trainer
from trainers.adversarial_trainer import AdvTrainer
def SetSeed(seed):
    """Set the random seed for torch (CPU and all GPUs) and numpy.

    Arguments:
        seed {int} -- seed number, applied to torch and numpy
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    numpy.random.seed(seed)
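# Note: for fully reproducible CUDA runs one would additionally set
# torch.backends.cudnn.deterministic = True and
# torch.backends.cudnn.benchmark = False; they are left at their defaults
# here to preserve the original training behavior.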
TIME_EMB = ['Trigo', 'Linear']
PROB_DEC = ['CNF', 'Diffusion', 'GAN', 'ScoreMatch', 'VAE', 'LogNorm', 'Gompt', 'Gaussian', 'Weibull', 'FNN', 'THP', 'SAHP']
# NOTE: THP and SAHP model each event type with its own intensity function
# (type-wise intensity modeling), while the other decoders model all types in
# a single sequence (see the sketch below). The final metric evaluation
# therefore follows a different protocol for these two decoders.
HIST_ENC = ['LSTM', 'Attention']
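# A minimal illustration of the two protocols (the shapes below are
# assumptions for illustration, not the actual model API): for K event types
# and encoded history h,
#
#   type-wise (THP, SAHP):    lambda_k(t | h) for k = 1, ..., K
#                             -- one intensity function per event type
#   single-sequence (others): p(tau | h) * p(k | h)
#                             -- one inter-event-time density plus a
#                                categorical distribution over types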
parser = argparse.ArgumentParser(prog="Attentive Diffusion Temporal Point Process (training)",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# Basic
parser.add_argument('--log_dir', type=str, metavar='DIR', default='experiments/',
                    help='Directory where models and logs will be saved.')
parser.add_argument('--dataset_dir', type=str, metavar='DIR', default='./data/stackoverflow/',
                    choices=['./data/mooc/', './data/retweet/', './data/stackoverflow/', './data/synthetic_n5_c0.2/', './data/yelp/'],
                    help='Directory for the dataset.')
# Training
parser.add_argument('--max_epoch', type=int, metavar='NUM', default=100,
                    help='The maximum number of training epochs.')
parser.add_argument('--lr', type=float, metavar='RATE', default=1e-3,
                    help='The learning rate for training.')
parser.add_argument('--load_epoch', type=int, metavar='NUM', default=0,
                    help='Saved epoch number to load when resuming training.')
parser.add_argument('--batch_size', type=int, metavar='SIZE', default=16,
                    help='Batch size for training.')
parser.add_argument('--val_batch_size', type=int, metavar='SIZE', default=8,
                    help='Batch size for validation; should be smaller than the training batch size because some metrics require MCMC sampling.')
parser.add_argument('--experiment_name', type=str, metavar='NAME', default=None,
                    help='Experiment name; used as the name of the directory where logs and models are saved.')
# Model
parser.add_argument('--time_emb', type=str, metavar='NAME', default='Trigo', choices=TIME_EMB,
                    help='The time embedding to use, chosen from {}.'.format(TIME_EMB))
parser.add_argument('--hist_enc', type=str, metavar='NAME', default='Attention', choices=HIST_ENC,
                    help='The history encoder to use, chosen from {}.'.format(HIST_ENC))
parser.add_argument('--prob_dec', type=str, metavar='NAME', default='Diffusion', choices=PROB_DEC,
                    help='The probabilistic decoder to use, chosen from {}.'.format(PROB_DEC))
parser.add_argument('--embed_size', type=int, metavar='SIZE', default=32,
                    help='Hidden dimension of the model.')
parser.add_argument('--layer_num', type=int, metavar='NUM', default=1,
                    help='Number of layers in the model.')
parser.add_argument('--attention_heads', type=int, metavar='NUM', default=4,
                    help='Number of attention heads for the attention history encoder; must be a divisor of embed_size.')
# Device and reproducibility
parser.add_argument('--gpu', type=int, metavar='DEVICE', default=6,
                    help='GPU to use for training.')
parser.add_argument('--seed', type=int, metavar='SEED', default=42,
                    help='Random seed for training.')
# Diffusion
parser.add_argument('--diff_steps', type=int, metavar='NUM', default=1000,
                    help='Number of diffusion steps in the conditional temporal diffusion decoder.')
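# For reference, a minimal sketch of how `--diff_steps` typically
# parameterizes a noise schedule in a diffusion decoder. The linear schedule
# and the beta_start/beta_end values below are common defaults assumed for
# illustration; the actual schedule is defined in models.prob_decoders.
def _linear_beta_schedule(diff_steps, beta_start=1e-4, beta_end=0.02):
    """Hypothetical helper: `diff_steps` evenly spaced noise levels."""
    return torch.linspace(beta_start, beta_end, diff_steps)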
args = parser.parse_args()
args = vars(args)
if __name__ == '__main__':
    SetSeed(args['seed'])
    device = torch.device('cuda:{}'.format(args['gpu'])) if torch.cuda.is_available() else torch.device('cpu')
    data, event_type_num, seq_lengths, max_length, max_t, mean_log_dt, std_log_dt, max_dt \
        = load_dataset(**args, device=device)
    # Record dataset statistics so the model can be rebuilt from the saved args.
    args['event_type_num'] = int(event_type_num)
    args['max_length'] = int(max_length)
    args['max_t'] = max_t
    args['mean_log_dt'] = mean_log_dt
    args['std_log_dt'] = std_log_dt
    args['max_dt'] = max_dt
    if args['experiment_name'] is None:
        args['experiment_name'] = '{}_{}_{}_{}'.format(args['hist_enc'],
                                                       args['prob_dec'],
                                                       args['dataset_dir'].split('/')[-2],
                                                       args['seed'])
    path = Path(args['log_dir']) / args['experiment_name']
    path.mkdir(exist_ok=True, parents=True)
    sv_param = os.path.join(path, 'model_param.json')
    with open(sv_param, 'w') as file_obj:
        json.dump(args, file_obj)
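    # The saved model_param.json can later be reloaded to rebuild the same
    # configuration, e.g. in a hypothetical evaluation script:
    #     with open(sv_param) as f:
    #         args = json.load(f)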
    time_embedding, type_embedding, position_embedding = get_embedding(**args)
    hist_encoder = get_encoder(**args)
    prob_decoder = get_decoder(**args)
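    # TPPWarper composes the pipeline: event times and types are embedded,
    # the history encoder summarizes past events into a hidden state, and the
    # probabilistic decoder models the next inter-event time (and type)
    # conditioned on that state.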
    model = TPPWarper(time_embedding=time_embedding,
                      type_embedding=type_embedding,
                      position_embedding=position_embedding,
                      encoder=hist_encoder,
                      decoder=prob_decoder,
                      **args)
    if args['prob_dec'] == 'GAN':
        # Adversarial training: the TPP model is the generator, trained
        # against a Wasserstein discriminator.
        model_d = WasDiscriminator(**args)
        trainer = AdvTrainer(
            data=data,
            model_g=model,
            model_d=model_d,
            seq_length=seq_lengths,
            device=device,
            **args
        )
    else:
        trainer = Trainer(
            data=data,
            model=model,
            seq_length=seq_lengths,
            device=device,
            **args
        )
    trainer.train()
    trainer.final_test(n=1)
    trainer.plot_similarity('type_similarity_sof')
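# Example invocation (flags as defined above):
#   python tpp_train.py --dataset_dir ./data/stackoverflow/ \
#       --hist_enc Attention --prob_dec Diffusion --gpu 0 --seed 42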