-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathloadMel.lua
127 lines (103 loc) · 3.86 KB
/
loadMel.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
-- @author Sachin Mehta
--- ---
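-- File to load a segmentation dataset (8 classes; train/test file lists read from opt.datapath) and cache it
-- NOTE(review): this header originally said "Pascal Context" (likely copied from that loader); confirm the dataset name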
-- @module loadPascalContext
require 'image'
-- List files describing the training and validation splits; each line holds
-- comma-separated paths that are appended to opt.datapath
local trainFile = string.format('%s/train.txt', opt.datapath)
local valFile = string.format('%s/test.txt', opt.datapath)
-- Class names; presumably label value i maps to classesName[i] (labels are
-- validated to lie in 1..#classesName below)
local classesName = {
  'Background',
  '1', '2', '3', '4', '5', '6', '7',
}
local classes = #classesName
--- ----
-- Report whether a file can be opened for reading.
-- @function [parent=#loadPascalContext] check_file
-- @param #string name Path of the file to probe
-- @return #boolean true when io.open succeeds in read mode, false otherwise
--
local function check_file(name)
  local handle = io.open(name, "r")
  if handle == nil then
    return false
  end
  -- only probing: release the handle immediately
  io.close(handle)
  return true
end
-- Training file lists. These are globals, kept as-is because they are
-- presumably read by other scripts of the project (confirm before making them
-- local). Entries are stored at 0-based indices 0..N-1, so `#list` reports
-- N-1 rather than N.
trainImFileList = {}
trainIm1FileList = {}
trainLblFileList = {}
-- Offset added to every label value right after loading (0 = no shift)
labelAddVal = 0
-- Per-class pixel counts over the rescaled training labels (bins 1..classes)
local histClasses = torch.Tensor(classes):fill(0)

--- ----
-- Parse one split list file, record its paths and sanity-check its labels.
-- Each non-blank line must hold three comma-separated paths that are
-- appended to opt.datapath: image, label image, second image.
-- Every label image is loaded, rescaled to opt.imWidth x opt.imHeight,
-- shifted by labelAddVal, and has value 0 remapped to 1. The result must lie
-- in 1..classes. Exits the process with status 1 when the list file is missing.
-- @function [parent=#loadPascalContext] parse_split
-- @param #string listFile Path of the list file
-- @param #string splitName Split name used in the "file does not exist" message
-- @param #table imList Receives image paths (0-based)
-- @param #table lblList Receives label paths (0-based)
-- @param #table im1List Receives second-image paths (0-based)
-- @param hist Optional tensor of per-class counts, incremented in place
--
local function parse_split(listFile, splitName, imList, lblList, im1List, hist)
  if not check_file(listFile) then
    print(splitName .. ' file does not exist: ' .. listFile)
    -- non-zero status so calling shell scripts can detect the failure
    os.exit(1)
  end
  -- `lineNo` stays a global, as in the original script
  lineNo = 0
  for rawLine in io.lines(listFile) do
    -- strip trailing whitespace (including the '\r' left by CRLF list files)
    local line = rawLine:gsub('%s+$', '')
    -- blank lines (e.g. an empty last line) are skipped instead of crashing
    if line ~= '' then
      local col1, col2, col3 = line:match("([^,]+),([^,]+),([^,]+)")
      assert(col1, 'Malformed line in ' .. listFile .. ' (expected 3 comma-separated paths): ' .. line)
      imList[lineNo] = opt.datapath .. col1
      lblList[lineNo] = opt.datapath .. col2
      im1List[lineNo] = opt.datapath .. col3
      local labelIm = image.load(lblList[lineNo], 1, 'byte')
      -- 'simple' (nearest-neighbour) scaling does not blend label values
      labelIm = image.scale(labelIm, opt.imWidth, opt.imHeight, 'simple')
      labelIm:add(labelAddVal)
      labelIm[labelIm:eq(0)] = 1
      if hist then
        -- typeAs: histc returns a FloatTensor, while hist uses the default tensor type
        hist:add(torch.histc(labelIm:float(), classes, 1, classes):typeAs(hist))
      end
      -- compute the extrema once; they are used by the check and its message
      local maxLabel, minLabel = labelIm:max(), labelIm:min()
      assert(maxLabel <= classes and minLabel > 0, 'Label values should be between 1 and number of classes: max ' .. maxLabel .. ' min: ' .. minLabel)
      lineNo = lineNo + 1
    end
  end
  -- both lists are always filled together; kept as a consistency guard
  assert(#imList == #lblList, 'Number of images and labels are not equal')
end

--parse the training data (also accumulates the class histogram)
parse_split(trainFile, 'Training', trainImFileList, trainLblFileList, trainIm1FileList, histClasses)

--parse the validation data (labels are checked but not added to the histogram)
valImFileList = {}
valLblFileList = {}
valIm1FileList = {}
parse_split(valFile, 'Validation', valImFileList, valLblFileList, valIm1FileList)
-- Class weights from the training histogram: w_c = 1 / ln(1.02 + p_c), where
-- p_c is the fraction of counted training pixels with label c. Classes never
-- seen in training get weight 0 and a console message.
local classFreq = histClasses / histClasses:sum()
local classWeights = torch.Tensor(classes):fill(1)
for c = 1, classes do
  local weight = 0
  if histClasses[c] >= 1 then
    weight = 1 / torch.log(1.02 + classFreq[c])
  else
    print("Class " .. tostring(c) .. " not found")
  end
  classWeights[c] = weight
end
-- Gather the file lists and class statistics into one global table, which is
-- serialized below with torch.save.
dataCache = {
  trainIm = trainImFileList,
  trainIm1 = trainIm1FileList,
  trainlbl = trainLblFileList,
  valIm = valImFileList,
  valIm1 = valIm1FileList,
  vallbl = valLblFileList,
  classes = classes,
  labelAddVal = labelAddVal,
  classWeight = classWeights,
}

-- Ensure the cache directory exists, creating it only when it is missing.
-- `cmd` is assumed to be the torch.CmdLine object of the calling script.
local cacheDirReady = paths.dirp(opt.cacheDir) or paths.mkdir(opt.cacheDir)
if not cacheDirReady then
  cmd:error('Error: Unable to create a cache directory: ' .. opt.cacheDir .. '\n')
end

-- Persist the dataset description
torch.save(opt.dataCacheFileName, dataCache)