-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_SDF_Winding.py
More file actions
120 lines (94 loc) · 4.32 KB
/
train_SDF_Winding.py
File metadata and controls
120 lines (94 loc) · 4.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import yaml
import time
import torch
import importlib
import numpy as np
import os.path as osp
from torch.backends import cudnn
from torch.utils.tensorboard import SummaryWriter
from trainers.helper import load_pts, comp_winding
from trainers.standard_utils import AverageMeter, dict2namespace, load_imf
def get_args(input_config):
    """Load a YAML config, point its output dirs at the SDF step, and snapshot it.

    The YAML file is parsed and converted with ``dict2namespace``.
    ``log_name``, ``save_dir`` and ``log_dir`` are all rewritten to
    ``logs/<original log_name>/SDF_step``. The resulting namespace is then
    dumped to ``<log_dir>/config/config.yaml``.

    Args:
        input_config: path to the YAML configuration file.

    Returns:
        The config namespace with the three directory fields rewritten.
    """
    # parse config file
    # NOTE(review): yaml.Loader can construct arbitrary Python objects -- only
    # load trusted config files (yaml.safe_load would be safer if configs never
    # rely on Python-specific tags).
    with open(input_config, "r", encoding="utf-8") as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config = dict2namespace(config)

    # All SDF-step outputs (tensorboard logs, checkpoints, config snapshot)
    # share one directory. %-formatting also accepts non-string log names.
    step_dir = "logs/%s/SDF_step" % (config.log_name,)
    config.log_name = step_dir
    config.save_dir = step_dir
    config.log_dir = step_dir

    # exist_ok: re-running the same config must not crash on the directory left
    # by a previous run (the config snapshot below is then overwritten).
    config_dir = osp.join(config.log_dir, "config")
    os.makedirs(config_dir, exist_ok=True)
    with open(osp.join(config_dir, "config.yaml"), "w", encoding="utf-8") as outf:
        yaml.dump(config, outf)
    return config
def _is_due(counter, freq):
    """Return True when a periodic action with period ``freq`` is due at ``counter``.

    ``freq`` is passed through ``int()`` because config values may not be ints.
    A non-positive ``freq`` disables the action. The positivity test runs before
    the modulo so that ``freq == 0`` returns False instead of raising
    ZeroDivisionError.
    """
    freq = int(freq)
    return freq > 0 and counter % freq == 0


def _load_oriented_points(path):
    """Load a point cloud with per-point input normals from a text file.

    Each row must hold at least six values. Columns 0-2 are the point
    position and columns 3-5 the input normal; extra columns are ignored.

    Returns:
        (points, normals): two arrays with three columns each.

    Raises:
        ValueError: if the file yields no rows or fewer than six columns.
    """
    # ndmin=2 keeps a single-point file two-dimensional so column slicing works.
    data = np.loadtxt(path, ndmin=2)
    if data.shape[0] == 0 or data.shape[1] < 6:
        raise ValueError(
            "expected rows of at least 6 values (x y z nx ny nz) in %r, "
            "got array of shape %s" % (path, data.shape)
        )
    return data[:, :3], data[:, 3:6]


def _load_step_networks(cfg):
    """Load the ``near`` heat-step network and, if configured, the ``far`` one.

    ``cfg.input.far_path`` set to the string ``"None"`` or to a YAML null both
    mean "no far network", in which case ``None`` is returned in its place.
    """
    near_net, _ = load_imf(cfg.input.near_path)
    far_path = cfg.input.far_path
    if far_path is None or far_path == "None":
        return near_net, None
    far_net, _ = load_imf(far_path)
    return near_net, far_net


def main_worker(cfg, O_pts_path):
    """Train the SDF step from an oriented point cloud and pre-trained heat-step nets.

    Args:
        cfg: config namespace. This function reads ``log_name``, ``trainer.type``,
            ``trainer.epochs``, ``input.near_path``, ``input.far_path``,
            ``input.parameters.domain_bound``, ``viz.log_freq`` and ``viz.save_freq``.
        O_pts_path: text file with one point per row (position, then input normal).

    Each epoch does the following:
      * runs ``trainer.update`` on every (inner, outer) batch from ``comp_winding``;
      * validates on the last batch;
      * saves the best-validation checkpoint;
      * steps ``trainer.sch`` with the validation loss;
      * saves every ``viz.save_freq`` epochs.

    A final ``vis=True`` save follows training. The tensorboard writer is
    always closed, even if training raises.

    Raises:
        ValueError: if ``cfg.trainer.epochs`` is not positive or the point file
            is malformed.
        RuntimeError: if ``comp_winding`` yields no batches at all.
    """
    # basic setup
    cudnn.benchmark = True
    writer = SummaryWriter(log_dir=cfg.log_name)
    try:
        trainer_lib = importlib.import_module(cfg.trainer.type)
        trainer = trainer_lib.Trainer(cfg)
        start_epoch = 0
        end_epoch = cfg.trainer.epochs + start_epoch
        if end_epoch <= start_epoch:
            # Otherwise the final save below would reference an unbound `epoch`.
            raise ValueError("cfg.trainer.epochs must be positive, got %r"
                             % (cfg.trainer.epochs,))
        # NOTE(review): the global step assumes exactly this many batches per
        # epoch; if comp_winding yields a different count, step numbers of
        # consecutive epochs overlap or leave gaps -- confirm against comp_winding.
        batches_per_epoch = 1000
        start_time = time.time()
        print("Start epoch: %d End epoch: %d" % (start_epoch, end_epoch))
        step = 0
        duration_meter = AverageMeter("Duration")
        loader_meter = AverageMeter("Loader time")
        best_val = np.inf

        # load point cloud and input normals
        points, normal_data = _load_oriented_points(O_pts_path)

        # load heat step networks
        near_net, far_net = _load_step_networks(cfg)
        # Threshold at 3/5 of the largest near_net output on the input points.
        # Only a detached scalar is needed, so skip building an autograd graph.
        with torch.no_grad():
            near_vals = near_net(torch.from_numpy(points).float().cuda())
            kappa = (3 / 5) * near_vals.max().cpu().detach().numpy()

        # use winding numbers for inner/outer regions (requires surface normals)
        domain_bound = cfg.input.parameters.domain_bound
        dataloader = comp_winding(net=near_net, tresh=kappa, N=normal_data,
                                  P=points, domainbound=domain_bound)
        # move input points to cuda
        points = torch.from_numpy(points).float().cuda()

        # Validation reuses the most recent training batch; None means "no batch yet".
        inner_pts = outer_pts = None
        # start actual training loop
        for epoch in range(start_epoch, end_epoch):
            loader_start = time.time()
            for batchnumber, (inner_pts, outer_pts) in enumerate(dataloader, start=1):
                inner_pts = inner_pts.squeeze(0)  # shape [batch_size, 3]
                outer_pts = outer_pts.squeeze(0)  # shape [batch_size, 3]
                loader_meter.update(time.time() - loader_start)
                step = batchnumber + batches_per_epoch * epoch

                # Compute loss and update
                logs_info = trainer.update(cfg, input_points=points, near_net=near_net,
                                           far_net=far_net, kappa=kappa,
                                           inner_sample=inner_pts, outer_sample=outer_pts,
                                           epoch=epoch, step=step)
                # _is_due tests log_freq > 0 before the modulo, so a log_freq of 0
                # disables logging instead of raising ZeroDivisionError.
                if _is_due(step, cfg.viz.log_freq):
                    duration = time.time() - start_time
                    duration_meter.update(duration)
                    start_time = time.time()
                    print("Epoch %d Batch [%2d/%2d] Time [%3.2fs] Loading [%3.2fs] Loss %2.5f" %
                          (epoch, batchnumber, batches_per_epoch, duration_meter.avg,
                           loader_meter.avg, logs_info['loss']))
                    trainer.log_train(logs_info, writer=writer, epoch=epoch, step=step,
                                      visualize=False)
                loader_start = time.time()

            if inner_pts is None:
                raise RuntimeError("comp_winding produced no training batches; "
                                   "nothing to train or validate on")
            val_loss = trainer.validate(cfg, input_points=points, near_net=near_net,
                                        far_net=far_net, kappa=kappa,
                                        inner_sample=inner_pts, outer_sample=outer_pts,
                                        epoch=epoch, writer=writer)['loss']
            if val_loss < best_val:
                trainer.save_best_val(epoch, step)
                best_val = val_loss
            trainer.sch.step(val_loss)
            if _is_due(epoch + 1, cfg.viz.save_freq):
                trainer.save(epoch=epoch, step=step, vis=False)
        trainer.save(epoch=epoch, step=step, vis=True)
    finally:
        writer.close()
def run_training(input_config, O_pts_path):
    """Run the SDF step end to end: build the config, echo it, then train.

    Args:
        input_config: path to the YAML configuration file (handed to ``get_args``).
        O_pts_path: path to the oriented point-cloud file (handed to ``main_worker``).
    """
    config = get_args(input_config)
    # Header line followed by the namespace itself, one per line.
    print("Configuration:", config, sep="\n")
    main_worker(config, O_pts_path)