-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
76 lines (68 loc) · 1.83 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#coding:utf8
import warnings
class DefaultConfig(object):
    """Default settings, grouped into training, model and data parameters.

    Every setting is a class attribute, so the defaults are available on the
    class itself. Call ``parse`` on an instance to override them.

    NOTE: the three ``*_params`` dicts are class-level objects shared by all
    instances; mutating one in place changes it for every instance.
    """

    env = 'default'
    model = 'EVQA'

    # Optimisation schedule, logging/checkpoint intervals and cache location.
    train_params = {
        'learning_rate': 1e-3,
        'lr_decay_n_iters': 2000,
        'lr_decay_rate': 0.8,
        'max_epoches': 1000,
        'early_stopping': 10,
        'cache_dir': './model/tsn/evqa/',
        'display_batch_interval': 20,
        'summary_interval': 10,
        'evaluate_interval': 5,
        'saving_interval': 1000,
        'epoch_reshuffle': True,
        'weight_decay': 1e-4,
    }

    # Network architecture sizes and regularisation settings.
    model_params = {
        'use_q_len': True,
        'model_name': 'qvec_all_att_model',
        'lstm_dim': 384,
        'ref_dim': 300,
        'ques_embed_dim': 100,  # original: 300
        'attention_dim': 256,
        'regularization_beta': 1e-7,
        'use_notifier': True,
        'dropout_prob': 0.6,
        'boundary_dim': 128,
        'decode_dim': 256,
    }

    # Dataset description, feature dimensions and input file locations.
    data_params = {
        # general
        'dataset': 'tsn',
        'batch_size': 64,
        'n_types': 5,
        # NOTE(review): the original comment says this value changes every
        # time -- presumably it tracks the vocabulary size; confirm.
        'n_words': 9611,
        # feature
        'use_frame': True,
        'input_video_dim': 404,
        'max_n_frames': 240,
        # question answer
        'max_n_q_words': 20,
        'max_n_a_words': 10,
        'use_qvec': True,
        'input_ques_dim': 300,
        # path (absolute, machine-specific locations)
        'wordvec_path': '/home/szh/mm_2018/data/tsn_score/caption/wordvec.npy',
        'word_dict_path': '/home/szh/mm_2018/data/tsn_score/caption/worddict.pkl',
        'train_json': '/home/szh/mm_2018/data/tsn_score/train_clean.json',
        'val_json': '/home/szh/mm_2018/data/tsn_score/val_clean.json',
        'test_json': '/home/szh/mm_2018/data/tsn_score/test_clean.json',
        'feat_path': '/home/szh/mm_2018/data/tsn_score/feat/',
        'flow_path': '/home/szh/mm_2018/data/tsn_score/flow/',
    }

    def parse(self, kwargs):
        '''
        Update parameters according to kwargs, then print the configuration.

        Each key/value pair in ``kwargs`` (a dict) is stored as an attribute
        on this instance, shadowing the class-level default of the same name.
        A key that matches no existing attribute triggers a warning but is
        still stored. Returns None.
        '''
        for k, v in kwargs.items():
            if not hasattr(self, k):
                warnings.warn("Warning: opt has not attribute %s"%k)
            # Apply the override; without this the passed values were ignored.
            setattr(self, k, v)

        print("user config:")
        # Walk the class-level names (this also lists ``parse`` itself) and
        # read each through the instance so overrides are shown.
        for k, v in self.__class__.__dict__.items():
            if not k.startswith('__'):
                print(k, getattr(self, k))
# Module-level configuration instance, created with the DefaultConfig defaults.
opt = DefaultConfig()