-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathscript.py
53 lines (53 loc) · 2.26 KB
/
script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
def predict():
    """Run the GMM stage and then the TOM stage on the validation pairs.

    For each stage this builds the option namespace, loads the stage's
    checkpoint, switches the model to eval mode and passes a non-shuffled
    ``DataLoader`` to that stage's ``run`` function inside
    ``torch.no_grad()``. The GMM stage is configured with
    ``out_dir='output/first'`` and the TOM stage with
    ``out_dir='output/second'``.

    Takes no arguments and returns ``None``; progress is reported on stdout.
    """
    mode = 'val'

    gmm_opt = _stage_options('GMM', './models/gmm_final.pth', 'output/first')
    _run_gmm_stage(gmm_opt, mode)

    tom_opt = _stage_options('TOM', './models/tom_final.pth', 'output/second')
    _run_tom_stage(tom_opt, mode)


def _stage_options(name, checkpoint, out_dir):
    """Build the option namespace for one stage; only three fields vary."""
    import argparse

    return argparse.Namespace(
        checkpoint=checkpoint,
        data_root='Database',
        out_dir=out_dir,
        name=name,
        batch_size=16,
        n_worker=4,
        gpu_id='0',
        log_freq=100,
        radius=5,
        fine_width=192,
        fine_height=256,
        grid_size=5,
    )


def _run_gmm_stage(opt, mode):
    """Load the GMM checkpoint and run it over the ``<mode>_pairs.txt`` list."""
    import torch
    # Imported here, not at module level, so run_gmm is only loaded when
    # this stage actually starts.
    from run_gmm import run, GMM, GMMDataset, load_checkpoint, DataLoader

    model = GMM(opt)
    load_checkpoint(model, opt.checkpoint)
    # The model is not moved to a GPU here (no ``model.cuda()`` call).
    model.eval()

    print('Run on {} data'.format(mode.upper()))
    dataset = GMMDataset(opt, mode, data_list=mode + '_pairs.txt', train=False)
    dataloader = DataLoader(
        dataset,
        batch_size=opt.batch_size,
        num_workers=opt.n_worker,
        shuffle=False,
    )
    with torch.no_grad():
        run(opt, model, dataloader, mode)
    print('Successfully completed')


def _run_tom_stage(opt, mode):
    """Load the TOM checkpoint and run it over the ``<mode>_pairs.txt`` list."""
    # torch is imported directly so this stage does not depend on the GMM
    # stage having pulled it into scope first.
    import torch
    # Imported here so run_tom is only loaded once the GMM stage has finished.
    from run_tom import (
        run, UnetGenerator, nn, load_checkpoint, TOMDataset, DataLoader,
    )

    model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)
    load_checkpoint(model, opt.checkpoint)
    # The model is not moved to a GPU here (no ``model.cuda()`` call).
    model.eval()

    print('Run on {} data'.format(mode.upper()))
    dataset = TOMDataset(opt, mode, data_list=mode + '_pairs.txt', train=False)
    dataloader = DataLoader(
        dataset,
        batch_size=opt.batch_size,
        num_workers=opt.n_worker,
        shuffle=False,
    )
    with torch.no_grad():
        run(opt, model, dataloader, mode)
    print('Successfully completed')
# Resize image code