forked from wzmsltw/BSN-boundary-sensitive-network.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpem_jobs.py
More file actions
113 lines (91 loc) · 4.71 KB
/
pem_jobs.py
File metadata and controls
113 lines (91 loc) · 4.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
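"""Build and launch PEM training jobs from existing TEM inference results.

For every TEM results directory and each checkpoint subdirectory inside it,
build one PEM training job per hyperparameter combination and submit it to
the cluster.

Example commands:
    python pem_jobs.py
"""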
from copy import deepcopy
import os
import re
import sys
from run_on_cluster import fb_run_batch
from tem_jobs import run as temrun
email = 'cinjon@nyu.edu'
code_directory = '/private/home/cinjon/Code/BSN-boundary-sensitive-network.pytorch'
# Annotation directories inside the code checkout.
gymnastics_anno_directory = os.path.join(code_directory, 'data', 'gymnastics_annotations')
thumos_anno_directory = os.path.join(code_directory, 'data', 'thumos14_annotations')
# Root directory for checkpoints, TEM inference output and PGM output.
base_dir = '/checkpoint/cinjon/spaceofmotion/bsn'
checkpoint_path = os.path.join(base_dir, 'checkpoint', 'pem')
tem_dir = os.path.join(base_dir, 'teminf')
tem_results_dir = os.path.join(tem_dir, 'results')
pgm_proposals_dir = os.path.join(base_dir, 'pgmprops')
pgm_feats_dir = os.path.join(base_dir, 'pgmfeats')
# Captures a 5-digit job counter from a results directory name. The greedy
# leading `.*` means the LAST run of 5 digits in the name is captured. Raw
# string so `\d` reaches the regex engine instead of being an (invalid)
# string escape, which warns on recent Python versions.
regex = re.compile(r'.*(\d{5}).*')
num_gpus = 1  # GPUs per job; run() derives num_cpus, gb and pem_batch_size from it.
func = fb_run_batch  # launcher that run() calls once per job unless looking one up
def _pem_hparam_grid():
    """Yield the PEM hyperparameter combinations swept for each checkpoint.

    Yields ``(lr, weight_decay, l2_loss, milestones, step_gamma)`` tuples in a
    fixed nested-loop order, so the job counters assigned by ``run`` are
    reproducible. Combinations with both regularisers on, or both off, are
    skipped, so exactly one of weight decay / L2 loss is non-zero.
    """
    for pem_training_lr in [0.01]:
        for pem_weight_decay in [0.0, 1e-4]:
            for pem_l2_loss in [0.0, 0.000025]:
                if pem_weight_decay > 0 and pem_l2_loss > 0:
                    continue
                if pem_weight_decay == pem_l2_loss == 0.0:
                    continue
                for milestones in ['10,30', '10,20']:
                    for pem_step_gamma in [0.1, 0.5]:
                        yield (pem_training_lr, pem_weight_decay, pem_l2_loss,
                               milestones, pem_step_gamma)


def run(find_counter=None):
    """Build (and by default submit) PEM training jobs from TEM results.

    For each subdirectory of ``tem_results_dir`` (sorted), extracts the
    5-digit counter from its name and looks up the TEM job with that counter
    via ``tem_jobs.run``. It then drops the job's ``tem*`` keys and, for each
    checkpoint subdirectory and each combination from ``_pem_hparam_grid``,
    builds a PEM training job with a running ``counter`` id.

    Args:
        find_counter: If falsy, every job is submitted with ``func``.
            Otherwise nothing is submitted and the job whose counter equals
            ``find_counter`` is returned.

    Returns:
        The matching job dict when ``find_counter`` is given and reached,
        otherwise None.
    """
    # NOTE: set this to the final counter printed by the previous launch so
    # job ids stay unique. Earlier starts: 451, 715, 750, 782, 814, 854, 950,
    # 1054, 1150, 1382. Earlier ends: 782, 814, 854, 950, 1054, 1150, 1382,
    # 1486, 1622.
    counter = 1622
    check = 0  # number of jobs built; printed at the end for bookkeeping
    jobs_per_ckpt = sum(1 for _ in _pem_hparam_grid())
    activitynet_video_anno = os.path.join(
        code_directory, 'data', 'activitynet_annotations',
        'anet_anno_action.json')

    for tem_results_subdir in sorted(os.listdir(tem_results_dir)):
        print(tem_results_subdir)
        match = regex.match(tem_results_subdir)
        if match is None:
            # Without this guard a stray entry lacking a 5-digit counter
            # crashes the whole launch partway through.
            print('Skipping %s: no 5-digit job counter in its name.'
                  % tem_results_subdir)
            continue
        _counter = int(match.group(1))
        # Text after the matched counter digits. Slicing at the match avoids
        # str.split(str(_counter)), which loses leading zeros and also cuts at
        # any other occurrence of those digits (e.g. counter 1 in "00001.lr0.1").
        name_suffix = tem_results_subdir[match.end(1):]

        job = temrun(find_counter=_counter)
        if isinstance(job, tuple):
            job = job[1]
        # Drop the TEM job's ``tem*`` keys.
        job = {key: value for key, value in job.items()
               if not key.startswith('tem')}

        ckpt_root = os.path.join(tem_results_dir, tem_results_subdir)
        # Sorted so the counter given to each job does not depend on the
        # filesystem's (arbitrary) directory listing order.
        for ckpt_subdir in sorted(os.listdir(ckpt_root)):
            _job = deepcopy(job)
            dataset = _job['dataset']
            if 'thumos' in dataset:
                _job['video_anno'] = os.path.join(_job['video_info'], 'thumos_anno_action.json')
            elif 'gymnastics' in dataset:
                _job['video_anno'] = os.path.join(_job['video_info'], 'gymnastics_anno_action.sep052019.json')
            elif 'activitynet' in dataset:
                _job['video_anno'] = activitynet_video_anno
            _job['pgm_proposals_dir'] = os.path.join(pgm_proposals_dir, tem_results_subdir, ckpt_subdir)
            _job['pgm_features_dir'] = os.path.join(pgm_feats_dir, tem_results_subdir, ckpt_subdir)
            _job['module'] = 'PEM'
            _job['mode'] = 'train'
            _job['pem_compute_loss_interval'] = 1
            _job['pem_epoch'] = 40
            _job['pem_do_index'] = True
            # NOTE(review): video_anno above is chosen by substring match, but
            # this is an exact comparison; a dataset named e.g.
            # 'activitynet_x' would get Full_Annotation.csv appended. Confirm.
            if dataset != 'activitynet':
                _job['video_info'] = os.path.join(_job['video_info'], 'Full_Annotation.csv')
            _job['checkpoint_path'] = checkpoint_path
            _job['name'] = '%05d%s.%s' % (_counter, name_suffix, ckpt_subdir)
            _job['num_gpus'] = num_gpus
            _job['num_cpus'] = num_gpus * 10
            _job['gb'] = 64 * num_gpus
            _job['time'] = 6
            _job['pem_feat_dim'] = 48
            # Divided by num_gpus -- presumably a per-GPU batch size; confirm.
            _job['pem_batch_size'] = int(400 / num_gpus)

            for (pem_training_lr, pem_weight_decay, pem_l2_loss,
                 milestones, pem_step_gamma) in _pem_hparam_grid():
                counter += 1
                pem_job = deepcopy(_job)
                pem_job['pem_training_lr'] = pem_training_lr
                pem_job['pem_weight_decay'] = pem_weight_decay
                pem_job['pem_l2_loss'] = pem_l2_loss
                pem_job['pem_lr_milestones'] = milestones
                pem_job['pem_step_gamma'] = pem_step_gamma
                pem_job['name'] = '%s-%05d' % (_job['name'], counter)
                check += 1
                # A falsy find_counter (None, or 0) means "submit everything".
                if not find_counter:
                    func(pem_job, counter, email, code_directory)
                elif counter == find_counter:
                    return pem_job

    print(counter, check, check // jobs_per_ckpt)
if __name__ == '__main__':
    # Script entry point: build and submit every job (no counter lookup).
    run(find_counter=None)