diff --git a/torch_modules/LSTM_LN.lua b/torch_modules/LSTM_LN.lua
index ac6c1d4..c92531b 100644
--- a/torch_modules/LSTM_LN.lua
+++ b/torch_modules/LSTM_LN.lua
@@ -9,18 +9,20 @@ function LSTM.lstm(inputSize, hiddenSize)
   local prev_c = nn.Identity()()
   local prev_h = nn.Identity()()
 
-  function new_input_sum_bias(bias)
+  function new_input_sum()
     -- transforms input
-    local i2h = nn.Linear(inputSize, hiddenSize)(x)
+    local i2h = nn.Linear(inputSize, hiddenSize, false)(x) -- no bias
     -- transforms previous timestep's output
-    local h2h = nn.Linear(hiddenSize, hiddenSize)(prev_h)
-    return nn.CAddTable()({i2h, h2h})
+    local h2h = nn.Linear(hiddenSize, hiddenSize, false)(prev_h) -- no bias
+
+    -- add a bias term here
+    return nn.Add(hiddenSize)(nn.CAddTable()({i2h, h2h}))
   end
 
-  local in_gate = nn.Sigmoid()(new_input_sum_bias(0))
-  local forget_gate = nn.Sigmoid()(new_input_sum_bias(-4.))
-  local out_gate = nn.Sigmoid()(new_input_sum_bias(0))
-  local in_transform = nn.Tanh()(new_input_sum_bias(0))
+  local in_gate = nn.Sigmoid()(new_input_sum())
+  local forget_gate = nn.Sigmoid()(new_input_sum())
+  local out_gate = nn.Sigmoid()(new_input_sum())
+  local in_transform = nn.Tanh()(new_input_sum())
 
   local next_c = nn.CAddTable()({
     nn.CMulTable()({forget_gate, prev_c}),
@@ -34,4 +36,3 @@ function LSTM.lstm(inputSize, hiddenSize)
 end
 
 return LSTM
-
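For reference, a minimal usage sketch of the patched cell (not part of the diff). It assumes the rest of LSTM_LN.lua wires these nodes into an nngraph gModule that takes {x, prev_c, prev_h} for one timestep; the require path, sizes, and variable names below are illustrative.

-- Minimal usage sketch; assumes LSTM.lstm returns an nngraph gModule
-- taking {x, prev_c, prev_h}. The require path is illustrative.
require 'nn'
require 'nngraph'
local LSTM = require 'torch_modules.LSTM_LN'

local inputSize, hiddenSize, batch = 10, 20, 4
local cell = LSTM.lstm(inputSize, hiddenSize)

-- one timestep: fresh input plus zero-initialized cell and hidden state
local x      = torch.randn(batch, inputSize)
local prev_c = torch.zeros(batch, hiddenSize)
local prev_h = torch.zeros(batch, hiddenSize)

local out = cell:forward({x, prev_c, prev_h})  -- e.g. {next_c, next_h}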