-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtest.lua
102 lines (77 loc) · 3.25 KB
/
test.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
-- @author Sachin Mehta
require 'optim'
require 'xlua'
require 'image'
-- Confusion matrix that accumulates test-set predictions. It is intentionally
-- global: the test function (and possibly other scripts) refer to it by name.
confusionMatTest = optim.ConfusionMatrix(confClasses)

-- Logger for the per-epoch validation (test) error, written under opt.snap.
local valLogName = 'error_test_' .. opt.resumeEpoch .. '.log'
local valLogger = optim.Logger(paths.concat(opt.snap, valLogName))

-- Reusable single-sample GPU buffers. The label buffer is 128 pixels smaller
-- than the input in each spatial dimension (cropped target).
local inputs = torch.Tensor(1, 3, opt.imHeight, opt.imWidth):cuda()
local targets = torch.Tensor(1, opt.imHeight - 128, opt.imWidth - 128):cuda()
--- ----
-- Write the training and test confusion matrices to a text file.
-- Uses assert on io.open so an unwritable path fails with the OS error message
-- instead of a later "attempt to index a nil value".
-- @param #string filenameCon Destination path
local function writeConfusionMatrices(filenameCon)
  local fileCon = assert(io.open(filenameCon, 'w'))
  fileCon:write("--------------------------------------------------------------------------------\n")
  fileCon:write("Training:\n")
  fileCon:write("================================================================================\n")
  fileCon:write(tostring(confusionMatTrain))
  fileCon:write("\n--------------------------------------------------------------------------------\n")
  fileCon:write("Testing:\n")
  fileCon:write("================================================================================\n")
  fileCon:write(tostring(confusionMatTest))
  fileCon:write("\n--------------------------------------------------------------------------------")
  fileCon:close()
end

--- ----
-- Function to test the network.
-- Runs the model in evaluation mode over every validation image, accumulates
-- the criterion loss and the test confusion matrix, logs the mean loss, writes
-- the train/test confusion matrices to opt.snap/con-<epoch>.txt and finally
-- resets both confusion matrices.
-- @function [parent=#test] test
-- @param #number epoch Epoch number
-- @param #table dataset Table containing the information about the dataset
--   (reads valIm, vallbl and labelAddVal)
-- @return #number Mean validation error (also stored in the global validationErr)
local function test(epoch, dataset)
  local time = sys.clock()
  model:evaluate()

  -- '#' replaces the deprecated table.getn; valSize no longer leaks as a global.
  local valSize = #dataset.valIm
  -- validationErr stays global (as before) in case the driver script reads it.
  validationErr = 0

  for i = 1, valSize do
    xlua.progress(i, valSize)
    local rgbImg = image.load(dataset.valIm[i]):float()
    --rgbImg = image.scale(rgbImg, opt.imWidth, opt.imHeight)
    -- We learn the mean and STD using Batch Normalization.
    -- If you want to use fixed mean and std, uncomment the following lines and change the mean (0.5) and std (1) values
    --rgbImg[1]:add(-0.5)
    --rgbImg[2]:add(-0.5)
    --rgbImg[3]:add(-0.5)
    --rgbImg[1]:div(1)
    --rgbImg[2]:div(1)
    --rgbImg[3]:div(1)
    local lblImg = image.load(dataset.vallbl[i], 1, 'byte'):float()
    --lblImg = image.scale(lblImg, opt.imWidth, opt.imHeight, 'simple')
    lblImg:add(dataset.labelAddVal)
    -- Label 0 and labels above opt.classes are both remapped to class 1.
    lblImg[lblImg:eq(0)] = 1
    lblImg[lblImg:gt(opt.classes)] = 1
    -- Tensor:narrow(dim, firstIndex, length): opt.cropEnd is used as the crop
    -- LENGTH, not an end index, so it must equal the targets buffer size.
    local start_dim = opt.cropStart --64
    local end_dim = opt.cropEnd --256
    lblImg = lblImg:narrow(2, start_dim, end_dim)
    lblImg = lblImg:narrow(3, start_dim, end_dim)
    inputs[1] = rgbImg
    targets[1] = lblImg
    local output = model:forward(inputs)
    local err = criterion:forward(output, targets)
    validationErr = validationErr + err
    -- Predicted class = argmax over dimension 2 of the model output.
    local _, pred = output:max(2)
    confusionMatTest:batchAdd(pred:view(-1), targets:view(-1))
  end

  -- An empty validation list would otherwise produce 0/0 = NaN in the log.
  if valSize == 0 then
    print('Warning: validation set is empty; reporting zero validation error')
  end
  local divisor = math.max(valSize, 1)
  time = (sys.clock() - time) / divisor
  validationErr = validationErr / divisor
  print("==> time to test 1 sample = " .. (time*1000) .. 'ms')
  print('Validation Error: ' .. validationErr)
  valLogger:add{['Validation Error '] = validationErr,
    ['Epoch'] = epoch}

  --save the confusion matrix
  local filenameCon = paths.concat(opt.snap, 'con-' .. epoch .. '.txt')
  print('saving confusion matrix: ' .. filenameCon)
  writeConfusionMatrices(filenameCon)
  print('\n')

  confusionMatTest:zero()
  confusionMatTrain:zero()
  collectgarbage()
  return validationErr
end
-- Chunk return value: the test function defined above, so whatever loads this
-- file (e.g. dofile/require) receives the function.
return test