Commit 93861b5

Merge pull request #6 from pakozm/devel

Devel

2 parents: c292ec9 + a31a23f

File tree: 4 files changed (+54, -22 lines)


mapreduce/examples/April-ANN/common.lua

Lines changed: 15 additions & 8 deletions
@@ -97,13 +97,17 @@ local mapfn = function(key, value, emit)
   local train_func = deserialize_from_gridfs(gridfs, assert(conf.train_func))
   local trainer = train_func:get_state_table().last
   conf:read_only(true)
-  local weight_grads,loss_matrix = compute_gradients_and_loss(trainer,
-                                                              key, value,
-                                                              conf)
+  local weight_grads,loss_matrix,bunch_size =
+    compute_gradients_and_loss(trainer, key, value, conf)
   conf:read_only(false)
+  assert(weight_grads and loss_matrix and bunch_size,
+         "compute_gradients_and_loss had to return gradients, loss_matrix and bunch_size")
   for name,grads in pairs(weight_grads) do
     serialize_and_map_emit(name,
-                           { grads, trainer:weights(name):get_shared_count() },
+                           {
+                             grads,
+                             trainer:weights(name):get_shared_count()*bunch_size
+                           },
                            emit)
   end
   serialize_and_map_emit(TR_LOSS_KEY, loss_matrix, emit)

@@ -129,7 +133,7 @@ local reducefn = function(key, values, emit)
     end
     serialize_and_red_emit({ loss:get_accum_loss() }, emit)
   else
-    -- accumulate here the shared count
+    -- accumulate gradients and shared count
     local t = deserialize_emitted_value(values[1])
     local gradient = t[1]
     local counts = t[2]

@@ -165,10 +169,12 @@ local finalfn = function(pairs_iterator)
       tr_loss_mean = value[1]
       tr_loss_var = value[2]
     else
+      local N = value[2] if not N or N==0 then N=1 end
+      if params.smooth_gradients then
+        -- gradients smoothing
+        value[1]:scal( 1.0/math.sqrt(N) )
+      end
       weight_grads[key] = value[1]
-      local w = trainer:weights(key)
-      w:reset_shared_count()
-      w:add_to_shared_count(value[2])
     end
   end
   assert(tr_loss_mean)

@@ -214,6 +220,7 @@ local make_map_reduce_task_table = function(t)
     user_taskfn = { mandatory = true, type_match="function" },
     user_finalfn = { mandatory = true, type_match="function" },
     generate_new_trainer_and_train_func = { mandatory = true, type_match="function" },
+    smooth_gradients = { mandatory = false, default = true },
   }, t)
   --
   dbname = params.dbname
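
Taken together, these hunks change the gradient bookkeeping: the mapper now emits the shared count multiplied by the bunch size, the reducer accumulates gradients and counts, and finalfn optionally smooths each gradient by scaling it with 1.0/math.sqrt(N). A minimal standalone sketch of that smoothing arithmetic on plain Lua tables (the function and data below are illustrative, not library code):

-- Standalone illustration of the smoothing applied in finalfn: every
-- accumulated gradient is scaled by 1.0/math.sqrt(N), where N is the
-- shared count multiplied by the bunch size emitted by the mappers.
local function smooth_gradients(weight_grads, counts, smooth)
  for name, grad in pairs(weight_grads) do
    local N = counts[name]
    if not N or N == 0 then N = 1 end -- guard against missing counts
    if smooth then
      local s = 1.0 / math.sqrt(N)
      for i = 1, #grad do grad[i] = grad[i] * s end -- grad as a plain array
    end
  end
  return weight_grads
end

-- usage: N = shared_count * bunch_size for each weight matrix
local grads  = { w1 = { 0.5, -0.2 }, b1 = { 0.1 } }
local counts = { w1 = 4 * 128, b1 = 4 * 128 }
smooth_gradients(grads, counts, true)
print(grads.w1[1]) -- 0.5 / sqrt(512), roughly 0.0221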

mapreduce/examples/April-ANN/init.lua

Lines changed: 16 additions & 12 deletions
@@ -7,13 +7,13 @@ local NUM_REDUCERS = 10
 local EXP_DBHOST = "localhost"
 local EXP_DBNAME = "exp_digits"
 
-local bunch_size = 32
+local bunch_size = 128
 local weights_random = random(1234)
-local description = "256 inputs 256 tanh 128 tanh 10 log_softmax"
+local description = "256 inputs 128 tanh 10 log_softmax"
 local inf = -1
 local sup = 1
 local shuffle_random = random() -- TOTALLY RANDOM FOR EACH WORKER
-local learning_rate = 0.005
+local learning_rate = 0.01
 local momentum = 0.02
 local weight_decay = 1e-04
 local max_epochs = 40

@@ -77,7 +77,7 @@ local make_load_matrix = function(value)
   end
 end
 
-local make_load_dataset = function(mat)
+local make_load_dataset = function(mat,m2)
   return function()
     local train_input = dataset.matrix(mat,
                                        {

@@ -122,10 +122,11 @@ local make_load_dataset = function(mat)
   end
 end
 
--- receives the persistent table in read-only mode as last argument
+-- receives a trainer, key,value pair and the persistent table in read-only mode
+-- as last argument; returns gradients, loss_matrix and bunch_size
 local compute_gradients_and_loss = function(trainer, key, value, conf)
-  local mat = cached(value, make_load_matrix(value), mat_cache)
-  local ds_tbl = cached(value, make_load_dataset(mat), ds_cache)
+  local mat    = cached(value, make_load_matrix(value), mat_cache)
+  local ds_tbl = cached(value, make_load_dataset(mat,m2), ds_cache)
   local in_ds = ds_tbl.train_input
   local out_ds = ds_tbl.train_output
   local bunch_tbl = {}

@@ -136,15 +137,16 @@ local compute_gradients_and_loss = function(trainer, key, value, conf)
   local target = out_ds:getPatternBunch(bunch_tbl)
   local grads,tr_loss,tr_loss_matrix = trainer:compute_gradients_step(input,
                                                                       target)
-  return grads,tr_loss_matrix
+  return grads,tr_loss_matrix,bunch_size
 end
 
--- receives the persistent table in read-only mode as last argument
+-- receives a trainer and the persistent table in read-only mode as last
+-- argument; returns the validation loss mean and variance
 local compute_validation_loss = function(trainer, conf)
   util.omp_set_num_threads(4)
   local value = "misc/digits.png"
-  local mat = cached(value, make_load_matrix(value), mat_cache)
-  local ds_tbl = cached(value, make_load_dataset(mat), ds_cache)
+  local mat    = cached(value, make_load_matrix(value), mat_cache)
+  local ds_tbl = cached(value, make_load_dataset(mat,m2), ds_cache)
   local in_ds = ds_tbl.val_input
   local out_ds = ds_tbl.val_output
   local va_loss_mean,va_loss_var = trainer:validate_dataset{

@@ -155,7 +157,8 @@ local compute_validation_loss = function(trainer, conf)
   return va_loss_mean,va_loss_var
 end
 
--- the last argument is the persistent table (allows read/write operations)
+-- receives a train_func instance and the persistent table (allows read/write
+-- operations)
 local user_finalfn = function(train_func, conf)
   print(train_func:get_state_string())
   train_func:save("best_func.lua")

@@ -173,4 +176,5 @@ return common.make_map_reduce_task_table {
   generate_new_trainer_and_train_func = generate_new_trainer_and_train_func,
   compute_gradients_and_loss = compute_gradients_and_loss,
   compute_validation_loss = compute_validation_loss,
+  -- smooth_gradients = true, -- by default it is true
 }
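
The contract change for user code is the third return value: compute_gradients_and_loss must now also return the bunch size used for the gradient step, since mapfn asserts its presence. A hedged, self-contained sketch of the contract (the trainer here is a stub standing in for the real April-ANN trainer object):

local bunch_size = 128

-- stub trainer: the real one wraps an April-ANN network as in the diff above
local trainer = {
  compute_gradients_step = function(self, input, target)
    return { w1 = { 0.0 } }, 0.0, { 0.0 } -- gradients, loss, loss matrix
  end,
}

local compute_gradients_and_loss = function(trainer, key, value, conf)
  local input, target -- the real example builds these with getPatternBunch
  local grads, tr_loss, tr_loss_matrix =
    trainer:compute_gradients_step(input, target)
  return grads, tr_loss_matrix, bunch_size -- third value is the new part
end

print(select("#", compute_gradients_and_loss(trainer, "key", "value", {}))) -- 3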

mapreduce/task.lua

Lines changed: 21 additions & 1 deletion
@@ -243,7 +243,14 @@ function task:get_partition_args()
   return self.tbl.init_args
 end
 
--- JOB INTERFACE
+-- TASK INTERFACE
+
+local cache_map_ids = {}
+local cache_inv_map_ids = {}
+function task.reset_cache()
+  cache_map_ids = {}
+  cache_inv_map_ids = {}
+end
 
 -- workers use this method to load a new job in the caller object
 function task:take_next_job(tmpname)

@@ -265,6 +272,12 @@ function task:take_next_job(tmpname)
       { status = STATUS.BROKEN, },
     },
   }
+  -- after first iteration, map jobs done previously will be taken if possible,
+  -- reducing the overhead for loading data
+  if self:get_iteration() > 1 and task_status == TASK_STATUS.MAP then
+    query._id = { ["$in"] = cache_map_ids }
+    if db:count(jobs_ns, query) == 0 then query._id = nil end
+  end
   local set_query = {
     worker = utils.get_hostname(),
     tmpname = tmpname_summary(tmpname),

@@ -282,6 +295,13 @@ function task:take_next_job(tmpname)
   -- updated its data
   local job_tbl = db:find_one(jobs_ns, set_query)
   if job_tbl then
+    if task_status == TASK_STATUS.MAP then
+      local _id = job_tbl._id
+      if not cache_inv_map_ids[_id] then
+        cache_inv_map_ids[_id] = true
+        table.insert(cache_map_ids, _id)
+      end
+    end
     local storage,path = self:get_storage()
     return task_status,job(self.cnn, job_tbl, task_status,
                            self:get_fname(), self:get_args(),
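
The cache pairs a hash set (cache_inv_map_ids) with an array (cache_map_ids): the set gives O(1) duplicate checks, while the array keeps the plain list of ids that is later dropped into the MongoDB { ["$in"] = ... } query. A standalone sketch of the idiom (names here are illustrative):

-- set+array idiom: remember each id once, in insertion order
local seen, ids = {}, {}
local function remember(_id)
  if not seen[_id] then
    seen[_id] = true
    table.insert(ids, _id)
  end
end

remember("job-1"); remember("job-2"); remember("job-1") -- duplicate ignored
print(#ids) -- 2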

mapreduce/worker.lua

Lines changed: 2 additions & 1 deletion
@@ -36,7 +36,7 @@ local utils = require "mapreduce.utils"
 local task = require "mapreduce.task"
 local cnn = require "mapreduce.cnn"
 
--- PRIVATE FUNCTIONS
+-- PRIVATE FUNCTIONS AND PROPERTIES
 
 -- executes the worker main loop; it runs querying the task object for new jobs
 local worker_execute = function(self)

@@ -92,6 +92,7 @@ local worker_execute = function(self)
       ntasks = ntasks + 1
       job_done = false
       job.reset_cache()
+      task.reset_cache()
     end
     if ntasks < MAX_TASKS then
       print(string.format("# WAITING...\tntasks: %d/%d\tit: %d/%d\tsleep: %.1f",
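
Both job and task keep their caches as module-level locals, so a worker that starts a fresh task must clear them explicitly, which is what the added task.reset_cache() call does. A minimal sketch of the pattern, assuming a module that memoizes by key (M, get and compute are illustrative names, not the library's API):

-- cache lives as an upvalue of the module, surviving across calls
local M = {}
local cache = {}

function M.reset_cache() -- called by the worker between tasks
  cache = {}
end

function M.get(key, compute)
  if cache[key] == nil then cache[key] = compute(key) end
  return cache[key]
end

return M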
