Open
Description
Hi, I modified the code fb.resnet.torch/dataloader.lua in order to read data triplet by triplet. But I encountered with an confusing error:
FATAL THREAD PANIC: (write) /home/haha/torch/install/share/lua/5.1/torch/File.lua:141:
Unwritable object <userdata> at <?>.callback.self.resnet.DataLoader.threads.__gc__
Below is my code...
function DataLoader:run()
local threads = self.threads
local size, batchSize = self.__size, self.batchSize
local perm = torch.randperm(size)
local tripletList = self:genTriplet()
local idx, sample = 1, nil
local function enqueue()
while idx <= size and threads:acceptsjob() do
local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
threads:addjob(
function(indices, nCrops, tripletList)
local sz = indices:size(1) * 3 --should be 3 times as previous, since now it is triplet
local batch, imageSize
local target = torch.IntTensor(sz)
for i, idx in ipairs(indices:totable()) do
local idx_anchor = tripletList[idx][1]
local idx_positive = tripletList[idx][2]
local idx_negative = tripletList[idx][3]
local sample_anchor = _G.dataset:get(idx_anchor) --get images
local sample_positive = _G.dataset:get(idx_positive)
local sample_negative = _G.dataset:get(idx_negative)
local input_anchor = _G.preprocess(sample_anchor.input)
local input_positive = _G.preprocess(sample_positive.input)
local input_negative = _G.preprocess(sample_negative.input)
if not batch then
imageSize = input_anchor:size():totable()
if nCrops > 1 then table.remove(imageSize, 1) end
batch = torch.FloatTensor(sz, nCrops, table.unpack(imageSize))
end
batch[(i-1)*2 + 1]:copy(input_anchor)
batch[(i-1)*2 + 2]:copy(input_positive)
batch[self.samples*self.blocks + i]:copy(input_negative)
target[(i-1)*2 + 1] = sample_anchor.target
target[(i-1)*2 + 2] = sample_positive.target
target[self.samples*self.blocks + i] = sample_negative.target
end
collectgarbage()
return {
input = batch:view(sz * nCrops, table.unpack(imageSize)),
target = target,
}
end,
function(_sample_)
-- print ('WHAT????')
sample = _sample_
end,
indices,
self.nCrops,
tripletList
)
idx = idx + batchSize
end
end
local n = 0
local function loop()
enqueue()
if not threads:hasjob() then
return nil
end
threads:dojob()
if threads:haserror() then
threads:synchronize()
end
enqueue()
n = n + 1
return n, sample
end
return loop
end
Below is the original code:
function DataLoader:run()
local threads = self.threads
local size, batchSize = self.__size, self.batchSize
local perm = torch.randperm(size)
local idx, sample = 1, nil
local function enqueue()
while idx <= size and threads:acceptsjob() do
local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
threads:addjob(
function(indices, nCrops)
local sz = indices:size(1)
local batch, imageSize
local target = torch.IntTensor(sz)
for i, idx in ipairs(indices:totable()) do
local sample = _G.dataset:get(idx)
local input = _G.preprocess(sample.input)
if not batch then
imageSize = input:size():totable()
if nCrops > 1 then table.remove(imageSize, 1) end
batch = torch.FloatTensor(sz, nCrops, table.unpack(imageSize))
end
batch[i]:copy(input)
target[i] = sample.target
end
collectgarbage()
return {
input = batch:view(sz * nCrops, table.unpack(imageSize)),
target = target,
}
end,
function(_sample_)
sample = _sample_
end,
indices,
self.nCrops
)
idx = idx + batchSize
end
end
local n = 0
local function loop()
enqueue()
if not threads:hasjob() then
return nil
end
threads:dojob()
if threads:haserror() then
threads:synchronize()
end
enqueue()
n = n + 1
return n, sample
end
return loop
end
Metadata
Metadata
Assignees
Labels
No labels