import os

import hydra
import torch
from torch.utils.tensorboard import SummaryWriter
from omegaconf import DictConfig, OmegaConf
from loguru import logger

from utils.misc import compute_model_dim
from utils.io import mkdir_if_not_exists
from utils.plot import Ploter
from datasets.base import create_dataset
from datasets.misc import collate_fn_general, collate_fn_squeeze_pcd_batch
from models.base import create_model
from models.visualizer import create_visualizer


def train(cfg: DictConfig) -> None:
    """ Training entry point.

    Args:
        cfg: configuration dict
    """
    if cfg.gpu is not None:
        device = f'cuda:{cfg.gpu}'
    else:
        device = 'cpu'

    ## prepare the datasets for training and, optionally, visualization
    datasets = {
        'train': create_dataset(cfg.task.dataset, 'train', cfg.slurm),
    }
    if cfg.task.visualizer.visualize:
        datasets['test_for_vis'] = create_dataset(cfg.task.dataset, 'test', cfg.slurm, case_only=True)
    for subset, dataset in datasets.items():
        logger.info(f'Loaded {subset} dataset, size: {len(dataset)}')
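
    ## choose a collate function: PointTransformer consumes a squeezed point-cloud batch, every other scene model uses the general collation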
    if cfg.model.scene_model.name == 'PointTransformer':
        collate_fn = collate_fn_squeeze_pcd_batch
    else:
        collate_fn = collate_fn_general

    dataloaders = {
        'train': datasets['train'].get_dataloader(
            batch_size=cfg.task.train.batch_size,
            collate_fn=collate_fn,
            num_workers=cfg.task.train.num_workers,
            pin_memory=True,
            shuffle=True,
        ),
    }
    if 'test_for_vis' in datasets:
        dataloaders['test_for_vis'] = datasets['test_for_vis'].get_dataloader(
            batch_size=cfg.task.test.batch_size,
            collate_fn=collate_fn,
            num_workers=cfg.task.test.num_workers,
            pin_memory=True,
            shuffle=True,
        )

    ## create the model and optimizer
    model = create_model(cfg, slurm=cfg.slurm, device=device)
    model.to(device=device)
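
    ## gather only trainable parameters, so frozen modules (e.g. a pretrained scene model) are excluded from optimization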
    params = []
    nparams = []
    for n, p in model.named_parameters():
        if p.requires_grad:
            params.append(p)
            nparams.append(p.nelement())
            logger.info(f'add {n} {p.shape} for optimization')
    params_group = [
        {'params': params, 'lr': cfg.task.lr},
    ]
    optimizer = torch.optim.Adam(params_group) # use the Adam optimizer by default
    logger.info(f'{len(params)} parameter tensors for optimization.')
    logger.info(f'total number of trainable parameters: {sum(nparams)}.')

    ## create the visualizer if visualization during training is enabled
    if cfg.task.visualizer.visualize:
        visualizer = create_visualizer(cfg.task.visualizer)

    ## start training
    step = 0
    for epoch in range(cfg.task.train.num_epochs):
        model.train()
        for it, data in enumerate(dataloaders['train']):
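            ## move every tensor in the batch onto the target device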
            for key in data:
                if torch.is_tensor(data[key]):
                    data[key] = data[key].to(device)

            optimizer.zero_grad()
            data['epoch'] = epoch
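            ## the forward pass returns a dict of scalars; 'loss' is the total objective that is backpropagated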
            outputs = model(data)
            outputs['loss'].backward()
            optimizer.step()

            ## plot loss
            if (step + 1) % cfg.task.train.log_step == 0:
                total_loss = outputs['loss'].item()
                log_str = f'[TRAIN] ==> Epoch: {epoch+1:3d} | Iter: {it+1:5d} | Step: {step+1:7d} | Loss: {total_loss:.3f}'
                logger.info(log_str)
                for key in outputs:
                    val = outputs[key].item() if torch.is_tensor(outputs[key]) else outputs[key]
                    Ploter.write({
                        f'train/{key}': {'plot': True, 'value': val, 'step': step},
                        'train/epoch': {'plot': True, 'value': epoch, 'step': step},
                    })

            step += 1

        ## save a checkpoint every `save_model_interval` epochs
        if (epoch + 1) % cfg.save_model_interval == 0:
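            ## either keep a separate checkpoint per save epoch, or overwrite a single model.pth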
            save_path = os.path.join(
                cfg.ckpt_dir,
                f'model_{epoch}.pth' if cfg.save_model_seperately else 'model.pth'
            )
            save_ckpt(
                model=model, epoch=epoch, step=step, path=save_path,
                save_scene_model=cfg.save_scene_model,
            )

        ## visualize on the test split at the configured epoch interval
        if cfg.task.visualizer.visualize and (epoch + 1) % cfg.task.visualizer.interval == 0:
            vis_dir = os.path.join(cfg.vis_dir, f'epoch{epoch+1:0>4d}')
            visualizer.visualize(model, dataloaders['test_for_vis'], vis_dir)


def save_ckpt(model: torch.nn.Module, epoch: int, step: int, path: str, save_scene_model: bool) -> None:
    """ Save the current model and the corresponding training state.

    Args:
        model: the model to save
        epoch: current epoch
        step: current step
        path: save path
        save_scene_model: whether to also save the scene model weights
    """
    saved_state_dict = {}
    model_state_dict = model.state_dict()
    for key in model_state_dict:
        ## with a frozen pretrained scene model, skip its weights to save disk space
        if 'scene_model' in key and not save_scene_model:
            continue
        saved_state_dict[key] = model_state_dict[key]

    logger.info('Saving model! ' + ('[ALL]' if save_scene_model else '[Except SceneModel]'))
    torch.save({
        'model': saved_state_dict,
        'epoch': epoch, 'step': step,
    }, path)
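

## hydra composes the run configuration from ./configs/default.yaml and merges any command-line overrides on top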
@hydra.main(version_base=None, config_path="./configs", config_name="default")
def main(cfg: DictConfig) -> None:
    ## compute the modeling dimension according to the task
    cfg.model.d_x = compute_model_dim(cfg.task)
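
    ## when launched under SLURM, flip the slurm flag, which create_dataset and create_model consume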
    if os.environ.get('SLURM') is not None:
        cfg.slurm = True # update slurm config
        logger.remove(handler_id=0) # remove the default handler

    ## set up the output logger and the tensorboard writer
    logger.add(cfg.exp_dir + '/runtime.log')

    mkdir_if_not_exists(cfg.tb_dir)
    mkdir_if_not_exists(cfg.vis_dir)
    mkdir_if_not_exists(cfg.ckpt_dir)

    writer = SummaryWriter(log_dir=cfg.tb_dir)
    Ploter.setWriter(writer)

    ## begin the training process
    logger.info('Configuration: \n' + OmegaConf.to_yaml(cfg))
    logger.info('Begin training..')
    train(cfg)

    ## training is over; flush everything to disk
    writer.close() # close the SummaryWriter and flush all pending data to disk
    logger.info('End training..')


if __name__ == '__main__':
    main()
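
## example launch (hydra override syntax; the available keys are defined under ./configs):
##   python train.py gpu=0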