RNN doesn't work as expected #2185

Open
alerem18 opened this issue Feb 9, 2023 · 31 comments

@alerem18

alerem18 commented Feb 9, 2023

I tried to implement an RNN model to classify the MNIST dataset, but I get an accuracy of around 40-50% even after running it for more than 20 epochs, while in PyTorch I get an accuracy of up to 90% after just 4-5 epochs.

here is my code:

using Flux
using Flux: onehotbatch, onecold, params, gradient
using MLDatasets: MNIST
using Base.Iterators: partition
using TensorCast
using Statistics: mean
using Random: shuffle

#---------------------------------- DATA -------------------------------------
DATA_TRAIN = MNIST.traindata(Float32)
DATA_TEST = MNIST.testdata(Float32)

#-------------------------------- PREPROCESS DATA ------------------------------
@cast x_train[j][i, k] := DATA_TRAIN[1][i, j, k] # reshape to vector of size 28 with matrix of size 28 x 60000
@cast x_test[j][i, k] := DATA_TEST[1][i, j, k] # reshape to vector of size 28 with matrix of size 28 x 10000

# create onehotbatch for train label
y_train = onehotbatch(DATA_TRAIN[2], 0:9)
y_test = DATA_TEST[2]

#------------------------------ CONSTANTS ---------------------------------------
INPUT_DIM = size(x_train[1], 1)
OUTPUT_DIM = 10 # number of classes
LR = 0.001 # learning rate
EPOCHS = 100
BATCH_SIZE = 1000
TOTAL_SAMPLES = size(x_train[1], 2)

#--------------------------------- BUILD MODEL -----------------------------------
struct RnnModel
  rnn
  fc
end

Flux.@functor RnnModel

# pass input through MODEL
function (m::RnnModel)(input_data)

  # warmup rnn
  [m.rnn(x) for x ∈ input_data[1:end - 1]]

  # pass latest layer to fc layer
  m.fc(m.rnn(input_data[end]))
end

# build MODEL
model = RnnModel(
  Chain(RNN(INPUT_DIM, 128), relu, RNN(128, 64), relu, RNN(64, 32), relu),
  Chain(Dense(32, OUTPUT_DIM), softmax)
)

#----------------------------- HELPER FUNCTIONS --------------------------------------
loss_fn(x, y) = Flux.Losses.logitcrossentropy(model(x), y)
function accuracy(x, y)
  Flux.reset!(model)
  mean(onecold(model(x), 0:9) .== y)
end

θ = params(model) # model parameters to be updated during training
opt = Flux.ADAM(LR) # optimizer function

#---------------------------- RUN TRAINING ----------------------------------------------
for epoch ∈ 1:EPOCHS
  for idx ∈ partition(1:TOTAL_SAMPLES, BATCH_SIZE)
    Flux.reset!(model)
    features = [x[:, idx] for x ∈ x_train]
    labels = y_train[:, idx]
    gs = gradient(θ) do
      loss = loss_fn(features, labels)
      loss
    end

    # update model
    Flux.Optimise.update!(opt, θ, gs)
  end

  # evaluate model
  @info epoch
  @show accuracy(x_test, y_test)
end

What am I doing wrong?

@ToucheSir
Member

I'm surprised this works at all with the input format given. What does the PyTorch code look like and have you verified it's doing the same thing?

@alerem18
Author

> I'm surprised this works at all with the input format given. What does the PyTorch code look like and have you verified it's doing the same thing?

What should the format be?
Ignore the softmax combined with logitcrossentropy; I typed it wrongly here.
Shouldn't the input be a vector of length seq_len whose elements are matrices of size (features, batch_size)?

@alerem18
Author

alerem18 commented Feb 12, 2023

PyTorch is quite different: it takes input of shape (batch_size, seq_len, features).
Also, I get much worse results just by reshaping the data differently:
@cast x_train[i][j, k] := DATA_TRAIN[1][i, j, k] # reshape to vector of size 28 with matrix of size 28 x 60000
@cast x_test[i][j, k] := DATA_TEST[1][i, j, k] # reshape to vector of size 28 with matrix of size 28 x 10000
The reshapes above lead to a worse result.

@ToucheSir
Member

> pytorch is quite different, it got a shape of (batch_size, seq_len, features)

Flux supports something very similar. This is why it's important to see the PyTorch code as well, I have a feeling this is not an apples-to-apples comparison.

@alerem18
Author

alerem18 commented Feb 12, 2023

> > pytorch is quite different, it got a shape of (batch_size, seq_len, features)
>
> Flux supports something very similar. This is why it's important to see the PyTorch code as well, I have a feeling this is not an apples-to-apples comparison.

PyTorch implementation:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torchvision.datasets import MNIST
    from torchvision.transforms import ToTensor
    from torch.utils.data import DataLoader

    # ------------------------------ DATA -----------------------------------
    train_data = MNIST(train=True, root='data', transform=ToTensor())
    test_data = MNIST(train=False, root='data', transform=ToTensor())
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=1000, shuffle=True)


    # ---------------------------- MODEL --------------------------------------
    class RNN(nn.Module):
        def __init__(self, input_dim, output_dim):
            super(RNN, self).__init__()
            self.rnn = nn.RNN(input_dim, 128, batch_first=True)
            self.fc = nn.Linear(128, output_dim)

        def forward(self, x, h):
            x, h = self.rnn(x, h)
            x = F.relu(x)
            x = self.fc(x)
            # get last layer from rnn
            return x[:, -1, :], h

        def init_hidden(self, batch_size):
            return torch.zeros([1, batch_size, 128])


    # ----------------------- HELPER -----------------------------------
    # seq_len = 28, input_dim=28, num_classes=10
    model = RNN(input_dim=28, output_dim=10)
    loss_fn = nn.CrossEntropyLoss()  # includes softmax layer too so we don't need it in the model


    def accuracy(X, y):
        total_samples = X.shape[0]
        h = model.init_hidden(batch_size=total_samples)
        with torch.no_grad():
            pred_values, _ = model(X, h)
            return torch.sum(pred_values.max(1)[1] == y) / total_samples


    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # --------------------------------- TRAIN LOOP ------------------------
    for epoch in range(1, 11):
        for data in train_loader:
            features = data[0].squeeze(1) # convert (batch_size, 1, 28, 28) to (batch_size, 28, 28)
            h = model.init_hidden(batch_size=features.shape[0]) # hidden state
            labels = data[1]
            predicted_values, _ = model(features, h)
            loss = loss_fn(predicted_values, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # get test data
        test_data = next(iter(test_loader))
        test_features = test_data[0].squeeze(1) # convert (batch_size, 1, 28, 28) to (batch_size, 28, 28)
        test_labels = test_data[1]
        print(f"epoch : {epoch}\\10\taccuracy : {accuracy(test_features, test_labels)}")

epoch : 1\10 accuracy : 0.8370000123977661
epoch : 2\10 accuracy : 0.8920000195503235

Of course, I also used smaller batch sizes in Flux, like 64 and 32, but still got the same result.
I could reach 74% accuracy in Flux using 30 epochs and the Momentum optimizer, but in PyTorch we already have high accuracy after the first epoch.

@ToucheSir
Member

ToucheSir commented Feb 12, 2023

Thanks, will try to take a look over the next couple of days. One quick observation though:

> # includes softmax layer too so we don't need it in the model

This is also true for Flux's logitcrossentropy. So you shouldn't have a softmax in your Flux model either, and I suspect that having it is hurting performance because of the redundancy with the loss function.
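
To make the redundancy concrete, here is a minimal dependency-free sketch. `softmax_vec` and `logit_xent` are hand-rolled stand-ins written for this illustration (they are not Flux's implementations), and the logits are made-up values. It compares feeding raw logits into a loss that applies softmax internally with feeding already-softmaxed outputs into the same loss.

```julia
# Hand-rolled stand-ins for illustration only; these are not Flux's implementations.
softmax_vec(z) = (e = exp.(z .- maximum(z)); e ./ sum(e))
logit_xent(z, y) = -sum(y .* log.(softmax_vec(z)))   # softmax is applied inside the loss

logits = [4.0, 1.0, -2.0]   # made-up raw scores from a hypothetical final Dense layer
y = [1.0, 0.0, 0.0]         # one-hot target: the first class is correct

loss_once  = logit_xent(logits, y)               # logits go straight into the loss
loss_twice = logit_xent(softmax_vec(logits), y)  # softmax in the model AND in the loss

# A second softmax squashes its inputs into [0, 1], so with 3 classes the loss
# cannot drop below log(1 + 2/e), however confident the model is.
floor_3 = log(1 + 2 / exp(1))

println((loss_once = loss_once, loss_twice = loss_twice, floor_3 = floor_3))
```

With the extra softmax the loss is larger for a confident, correct prediction and is bounded away from zero, which is one plausible way the redundancy could slow training.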

@alerem18
Author

alerem18 commented Feb 12, 2023

Yes, I know. I've used logitcrossentropy without softmax, and also softmax with crossentropy, but still got the same results.
I just typed the Julia code wrongly here.

@skyleaworlder
Contributor

I haven't run the code, but is it possible that the model input is wrong? 50% accuracy really reminds me of a data-processing experience of my own.

@alerem18
Author

alerem18 commented Feb 13, 2023

> I haven't run the code, but is it possible that the model input is wrong? 50% accuracy really reminds me of a data-processing experience of my own.

How should I prepare my data then?
MNIST has a shape of 28 x 28 x 60000. If I use the second dimension as seq_len, the result is different from when I use the first dimension as seq_len,
while there shouldn't be any difference at all.
And the accuracy is not high for either input.
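
For reference, the two slicing choices can be compared on dummy data without loading MNIST. The sketch below assumes only the 28 x 28 x N layout mentioned above; the array contents, `N = 6`, and the variable names are made up for illustration.

```julia
# Dummy data with the layout discussed above: 28 x 28 x N (N = 6 here).
A = reshape(Float32.(1:28*28*6), 28, 28, 6)

# Sequence along dimension 2: 28 timesteps, each a (features = 28, batch = N) matrix.
seq_dim2 = [x for x in eachslice(A, dims=2)]
# Sequence along dimension 1: same shapes, but the image is scanned along the other axis.
seq_dim1 = [x for x in eachslice(A, dims=1)]

println(length(seq_dim2), " timesteps of size ", size(seq_dim2[1]))

# Mini-batch selection keeps the timestep axis and slices the batch axis.
idx = 1:4
batch = [x[:, idx] for x in seq_dim2]
println("mini-batch timestep size: ", size(batch[1]))

# Both sequences have identical shapes but present the pixels in a different order.
println("first timesteps equal? ", seq_dim2[1] == seq_dim1[1])
```

Both variants are valid inputs of the form "vector of seq_len matrices of size (features, batch_size)"; they differ only in whether the image is fed to the RNN row by row or column by column.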

@jeremiedb
Contributor

I took a shot at rewriting the model the way I would have implemented it. It reaches 91% accuracy right after the first epoch with a batch size of 64.
A batch size of 1000 also works fine; it just starts with an accuracy of 77% but reaches 96% after 10 epochs.

using Flux
using Flux: onehotbatch, onecold, params, gradient
using MLDatasets: MNIST
using Base.Iterators: partition
using Statistics: mean
using Random: shuffle

#---------------------------------- DATA -------------------------------------
DATA_TRAIN = MNIST.traindata(Float32)
DATA_TEST = MNIST.testdata(Float32)

#-------------------------------- PREPROCESS DATA ------------------------------
x_train = [x for x in eachslice(DATA_TRAIN[1], dims=2)] # reshape to vector of size 28 with matrix of size 28 x 60000
x_test = [x for x in eachslice(DATA_TEST[1], dims=2)] # reshape to vector of size 28 with matrix of size 28 x 10000

# create onehotbatch for train label
y_train = onehotbatch(DATA_TRAIN[2], 0:9)
y_test = DATA_TEST[2]

#------------------------------ CONSTANTS ---------------------------------------
INPUT_DIM = size(x_train[1], 1)
OUTPUT_DIM = 10 # number of classes
LR = 0.001f0 # learning rate
EPOCHS = 10
BATCH_SIZE = 64
TOTAL_SAMPLES = size(x_train[1], 2)

#--------------------------------- BUILD MODEL -----------------------------------
model = Chain(
  RNN(INPUT_DIM => 128, relu),
  Dense(128, OUTPUT_DIM)
)

#----------------------------- HELPER FUNCTIONS --------------------------------------
function loss_fn_2(m, x, y)
  out = [m(xi) for xi in x] # generate output for each of the 28 timesteps
  Flux.Losses.logitcrossentropy(out[end], y) # compute loss based on predictions of the latest timestep
end

function accuracy_eval(m, x, y)
  Flux.reset!(m)
  out = [m(xi) for xi in x]
  mean(onecold(out[end], 0:9) .== y)
end  

θ = params(model) # model parameters to be updated during training
opt = Flux.ADAM(LR) # optimizer function

#---------------------------- RUN TRAINING ----------------------------------------------
for epoch ∈ 1:EPOCHS
  for idx ∈ partition(1:TOTAL_SAMPLES, BATCH_SIZE)
    features = [x[:, idx] for x ∈ x_train]
    labels = y_train[:, idx]
    Flux.reset!(model)
    gs = gradient(θ) do
      loss = loss_fn_2(model, features, labels)
    end
    # update model
    Flux.Optimise.update!(opt, θ, gs)
  end

  # evaluate model
  @info epoch
  @show accuracy_eval(model, x_test, y_test)
end

I think the data preprocessing was done fine (I just dropped the TensorCast dependency as I ran into an issue with it and it felt simpler not to use it).

I'm really unclear what went wrong with your implementation. This is just speculation, but perhaps the gradients didn't get propagated through the following part:

  [m.rnn(x) for x ∈ input_data[1:end - 1]]
  m.fc(m.rnn(input_data[end]))

as there's no explicit passing of the initial computation to the second. Again, just a wild guess here.

@alerem18
Author

alerem18 commented Feb 15, 2023

Your code works, but I really don't know why my code isn't working if the data preprocessing is the same.
I tried a different implementation, similar to yours, for calculating the loss:

 using Flux
 using Flux: onehotbatch, onecold, params, gradient
 using MLDatasets: MNIST
 using Base.Iterators: partition, product
 using TensorCast
 using Statistics: mean
 using Random: shuffle
 using StatsBase
 using ChainRulesCore, Zygote
 ChainRulesCore.@non_differentiable foreach(f, ::Tuple{})
 Zygote.refresh()
# ---------------------------------- DATA -------------------------------------
 TRAIN_DATA, TRAIN_LABELS = MNIST.traindata(Float32)
 TEST_DATA, TEST_LABELS = MNIST.testdata(Float32)
 TRAIN_LABELS = onehotbatch(TRAIN_LABELS, 0:9)
 # convert 3d arrays to vector of 2d arrays
 @cast TRAIN_FEATURES[i][j, k] := TRAIN_DATA[i, j, k]
 @cast TEST_FEATURES[i][j, k] := TEST_DATA[i, j, k]

 INPUT_DIM = size(TRAIN_FEATURES[1], 1)
 DATA = [([x[:, idx] for x in TRAIN_FEATURES], TRAIN_LABELS[:, idx]) for idx ∈ partition(shuffle(1:size(TRAIN_LABELS, 2)), 1000)]

 # ----------------------------------- MODEL --------------------------------------------
 model = Chain(
     RNN(INPUT_DIM, 128, relu),
     Dense(128, 10)
 )
 # --------------------------------- HELPER -----------------------------------------------
 function loss_fn(X, Y)
     Flux.reset!(model)
     out = [model(x) for x ∈ X]
     Flux.Losses.logitcrossentropy(out[end], Y)
 end

 function accuracy(X, Y)
     Flux.reset!(model) # Only important for recurrent network
     out = [model(x) for x ∈ X]
     mean(onecold(out[end], 0:9) .== Y)
 end

 θ = params(model)
 opt = Flux.ADAM()
 evalcb() = @show(accuracy(TEST_FEATURES, TEST_LABELS))
 # ----------------------------------- TRAIN -------------------------
 Flux.@epochs 30 Flux.train!(loss_fn, θ, DATA, opt, cb = Flux.throttle(evalcb, 5))

still doesn't work

@jeremiedb
Contributor

Not on a computer right now, but I think you should remove the reset! from the loss function.
And therefore, stick to a custom training loop instead of train!

@alerem18
Author

> Not on a computer right now, but I think you should remove the reset! from the loss function. And therefore, stick to a custom training loop instead of train!

I found that if I remove the model argument from the loss and accuracy functions, I get bad results; otherwise it works as expected:
loss_fn(X, Y), accuracy(X, Y) ===> bad results
loss_fn(m, X, Y), accuracy(m, X, Y) ==> good results

Can you explain why this happens? It's very weird.

@mcabbott mcabbott added the RNN label Feb 16, 2023
@ToucheSir
Member

Can you show the before and after code for that change? It's not immediately clear what the difference would be.

@CarloLucibello
Member

@alerem18 if you manage to clarify what's the difference causing a bad result we can decide if we have an actual bug or not

@alerem18
Author

loss_fn(X, Y), accuracy(X, Y) ===> bad results
loss_fn(m, X, Y), accuracy(m, X, Y) ==> good results

Passing the model to the loss and accuracy functions works as expected. If you don't pass it to those functions, you get bad results: the model stops improving after a while, and accuracy stays around 50-60%.

@ToucheSir
Member

What we're asking for is full code examples that show the good and bad results. Without that, loss_fn(X, Y) and loss_fn(m, X, Y) could be completely different functions for all we know. Having a complete example will allow us to run and try to reproduce the behaviour.

@alerem18
Author

alerem18 commented Mar 12, 2023

using Flux
using Flux: gradient, logitcrossentropy, params, Momentum
using OneHotArrays: onecold, onehotbatch
using MLDatasets: MNIST
using Random: shuffle
using Statistics: mean
using Base.Iterators: partition

# ------------------- data --------------------------
train_x, train_y = MNIST(split=:train).features, MNIST(split=:train).targets
test_x, test_y = MNIST(split=:test).features, MNIST(split=:test).targets
train_y = onehotbatch(train_y, 0:9)
train_x = [x for x ∈ eachslice(train_x, dims=2)]
test_x = [x for x ∈ eachslice(test_x, dims=2)]
# ------------------ constants ---------------------
INPUT_SIZE = 28
NUM_CLASSES = 10
BATCH_SIZE = 1000
EPOCHS = 5
# ------------------ model --------------------------
model = Chain(
    RNN(INPUT_SIZE, 128, relu),
    RNN(128, 64, relu),
    Dense(64, NUM_CLASSES)
)

# ---------------- helper --------------------------
loss_fn(m, X, y) = logitcrossentropy([m(x) for x ∈ X][end], y)
accuracy(m, X, y) = mean(onecold([m(x) for x ∈ X][end], 0:9) .== y)
opt = Momentum()
θ = params(model)

# --------------- train -----------------------------
for epoch ∈ 1:EPOCHS
    for idx ∈ partition(shuffle(1:size(train_y, 2)), BATCH_SIZE)
        Flux.reset!(model)
        X = [x[:, idx] for x ∈ train_x]
        y = train_y[:, idx]
        gs = gradient(θ) do 
            loss_fn(model, X, y)
        end
        Flux.Optimise.update!(opt, θ, gs)
    end
    Flux.reset!(model)
    test_acc = accuracy(model, test_x, test_y)
    @info "Epoch : $epoch | accuracy : $test_acc"
end

[ Info: Epoch : 1 | accuracy : 0.3968
[ Info: Epoch : 2 | accuracy : 0.7918
[ Info: Epoch : 3 | accuracy : 0.896
[ Info: Epoch : 4 | accuracy : 0.9365
[ Info: Epoch : 5 | accuracy : 0.9465

Edit the loss and accuracy functions as below and you get these results:

loss_fn(X, y) = logitcrossentropy([model(x) for x ∈ X][end], y)
accuracy(X, y) = mean(onecold([model(x) for x ∈ X][end], 0:9) .== y)

[ Info: Epoch : 1 | accuracy : 0.2795
[ Info: Epoch : 2 | accuracy : 0.4944
[ Info: Epoch : 3 | accuracy : 0.3561
[ Info: Epoch : 4 | accuracy : 0.5146
[ Info: Epoch : 5 | accuracy : 0.5229
[ Info: Epoch : 6 | accuracy : 0.5467
[ Info: Epoch : 7 | accuracy : 0.598
[ Info: Epoch : 8 | accuracy : 0.6085
[ Info: Epoch : 9 | accuracy : 0.5953
[ Info: Epoch : 10 | accuracy : 0.6038
[ Info: Epoch : 11 | accuracy : 0.6063
[ Info: Epoch : 12 | accuracy : 0.6336
[ Info: Epoch : 13 | accuracy : 0.65
[ Info: Epoch : 14 | accuracy : 0.6488
[ Info: Epoch : 15 | accuracy : 0.5951
[ Info: Epoch : 16 | accuracy : 0.5911
[ Info: Epoch : 17 | accuracy : 0.61
[ Info: Epoch : 18 | accuracy : 0.6357
[ Info: Epoch : 19 | accuracy : 0.6215
[ Info: Epoch : 20 | accuracy : 0.6576

@jeremiedb
Contributor

Thanks, I can reproduce. The cause isn't obvious to me, but the behavior suggests that the reset! performed within the training loop fails to affect the model state actually used by the loss and accuracy functions.

In any case, it appears safer to pass the model explicitly to the loss and accuracy functions. This also looks like non-obvious behavior that can lead to unexpectedly bad results, so it would be worth documenting if we can confirm the root cause.

@ToucheSir
Member

A quick sanity check would be moving θ = params(model) inside the inner training loop and seeing if that makes a difference. It's not immediately obvious to me why it would, but might as well eliminate one possibility.

@jeremiedb
Contributor

Unfortunately, no luck with adding training params instantiation within the training loop. The following results in the same accuracy plateau around 60%:

        ps = params(model)
        Flux.reset!(model)
        gs = gradient(ps) do 
            loss_fn2(X, y)
        end

@alerem18
Author

In the modified code, the loss_fn and accuracy functions do not take the params of the model as input, and they call the model directly within the function to compute the loss and accuracy.

The params function is used to extract the trainable parameters of a model, which is necessary for computing gradients and updating the model parameters during training. When params is used, the optimizer is able to track the gradients of the model parameters and update them accordingly during optimization.

By not using params, the optimizer is not able to track the gradients of the model parameters correctly and this can lead to incorrect optimization and lower accuracy.

Therefore, not using params in the modified code is a mistake and can result in lower accuracy.

any thoughts?

@ToucheSir
Member

Part of the "magic" of passing a Params to gradient is that the trainable parameters do not have to be directly passed to the loss function. Instead, Zygote will track trainable parameters by object ID (basically hashes of memory addresses) and accumulate gradients accordingly. This is why we call using params working with "implicit parameters".

The problem here is that something is causing the aforementioned tracking to not work. Ordinarily both versions of the code should behave similarly, so this is a bug. It's also why we're moving away from magical implicit params to directly passing the model/trainable params to gradient and the loss function: it's way less bug-prone, easier to understand for users and easier to debug for developers.
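
As a rough picture of what "tracking by object ID" means, the toy sketch below uses Julia's built-in `IdDict`, which compares keys by identity rather than by value. This is only an analogy for the bookkeeping idea, not Zygote's actual implementation, and the names `W`, `W_copy`, and `tracked` are made up.

```julia
# Toy illustration of identity-keyed bookkeeping; this is NOT Zygote's actual code.
W = [1.0 2.0; 3.0 4.0]          # stands in for a trainable weight array
tracked = IdDict{Any,Any}()     # keys are compared by identity (===), not by value
tracked[W] = zeros(size(W))     # slot where a gradient for W would accumulate

W_copy = copy(W)                # equal values, but a different object in memory

println("same contents:      ", W == W_copy)
println("W is tracked:       ", haskey(tracked, W))
println("W_copy is tracked:  ", haskey(tracked, W_copy))
```

The point of the analogy is that only the exact array objects registered up front are recognised; anything that ends up operating on a different object, even one with identical values, falls outside the bookkeeping.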

@jeremiedb
Contributor

For reference, this is how you could use the new explicit gradient / Optimisers.jl mode:

loss_fn1(m, X, y) = logitcrossentropy([m(x) for x ∈ X][end], y)
accuracy1(m, X, y) = mean(onecold([m(x) for x ∈ X][end], 0:9) .== y)

rule = Flux.Optimisers.Adam()
opts = Flux.Optimisers.setup(rule, model);

for epoch ∈ 1:5
    for idx ∈ partition(shuffle(1:size(train_y, 2)), BATCH_SIZE)
        X = [x[:, idx] for x ∈ train_x]
        y = train_y[:, idx]
        Flux.reset!(model)
        gs = gradient(model) do m
            loss_fn1(m, X, y)
        end
        Flux.Optimisers.update!(opts, model, gs[1]);
    end
    Flux.reset!(model)
    test_acc = accuracy1(model, test_x, test_y)
    @info "Epoch : $epoch | accuracy : $test_acc"
end

@fujiehuang

The RNN gradient with Zygote might have a bug. Here's my short test code. Keeping the outputs in an array versus in a single variable gives me different gradients. How come?

using Flux 
using Random
Random.seed!(149)

layer1 = Flux.Recur(Flux.RNNCell(1 => 1, identity))

x = Float32[0.8, 0.9]
y = Float32(-0.7)

Flux.reset!(layer1)
e1, g1 = Flux.withgradient(layer1) do m
    yhat = 0.0
    for i in 1:2 
        yhat = m([x[i]])
    end
    loss = Flux.mse(yhat, y)
    println(loss)
    return loss 
end
println("flux gradients: ", g1[1])

Flux.reset!(layer1)
e2, g2 = Flux.withgradient(layer1) do m
    yhat = [m([x[i]]) for i in 1:2]
    loss = Flux.mse(yhat[end], y)
    println(loss)
    return loss 
end
println("flux gradients: ", g2[1])

@jeremiedb
Contributor

jeremiedb commented Mar 16, 2023

There is indeed something fishy going on with the RNN gradients.

using Flux 
layer2 = Flux.Recur(Flux.RNNCell(1, 1, identity))
layer2.cell.Wi .= 5.0
layer2.cell.Wh .= 4.0
layer2.cell.b .= 0f0
layer2.cell.state0 .= 7.0
x = [[2f0], [3f0]]
Flux.reset!(layer2)
ps = Flux.params(layer2)
e2, g2 = Flux.withgradient(ps) do
    out = [layer2(xi) for xi in x]
    sum(out[2])
end

julia> g2[ps[1]]
1×1 Matrix{Float32}:
 3.0

julia> g2[ps[2]]
1×1 Matrix{Float32}:
 38.0

julia> g2[ps[3]]
1-element Fill{Float32}, with entry equal to 1.0

julia> g2[ps[4]] # nothing

Theoretical gradients are:

julia> ∇Wi = x[1] .* layer2.cell.Wh .+ x[2] 
1×1 Matrix{Float32}:
 11.0

julia> ∇Wh = 2 .* layer2.cell.Wh .* layer2.cell.state0 .+ x[1] .* layer2.cell.Wi 
1×1 Matrix{Float32}:
 66.0

julia> ∇b = layer2.cell.Wh .+ 1
1×1 Matrix{Float32}:
 5.0

julia> ∇state0 = layer2.cell.Wh .^ 2
1×1 Matrix{Float32}:
 16.0
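
The theoretical values above can be cross-checked without Flux or Zygote. The sketch below is a plain-Julia finite-difference check, assuming the same scalar setup (Wi = 5, Wh = 4, b = 0, state0 = 7, inputs 2 and 3, identity activation); `rnn_out` and `fd_grad` are helper names made up for this illustration.

```julia
# Scalar two-step RNN with identity activation, mirroring the values used above:
# h1 = Wi*x1 + Wh*h0 + b;  h2 = Wi*x2 + Wh*h1 + b;  the "loss" is h2 itself.
function rnn_out(Wi, Wh, b, h0; x = (2.0, 3.0))
    h = h0
    for xi in x
        h = Wi * xi + Wh * h + b
    end
    return h
end

# Central finite difference with respect to the i-th entry of the parameter vector.
function fd_grad(f, θ::Vector{Float64}, i; ε = 1e-4)
    θp = copy(θ); θp[i] += ε
    θm = copy(θ); θm[i] -= ε
    return (f(θp...) - f(θm...)) / (2ε)
end

θ = [5.0, 4.0, 0.0, 7.0]                     # Wi, Wh, b, state0
grads = [fd_grad(rnn_out, θ, i) for i in 1:4]

for (name, g) in zip(("∇Wi", "∇Wh", "∇b", "∇state0"), grads)
    println(name, " ≈ ", round(g; digits = 4))
end
```

Because the output is at most quadratic in each parameter, central differences agree with the analytic values (11, 66, 5, 16) up to floating-point rounding, which makes this a convenient reference when comparing against the gradients an AD backend returns.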

Worse, the gradients are different (yet still wrong) when using the explicit mode :\

I tested on older versions of Flux and things got even weirder. I got the same bad gradients going back to v0.11.4. However, when trying on Julia 1.6.5... correct gradients with all tested Flux versions, v0.11.4 up to v0.13.4, and the latest Zygote v0.6.58 (both implicit and explicit modes)!

The same bad gradients were observed on Julia 1.7.2 and 1.9.0-rc1.

So, it seems like something changed between Julia v1.6 and v1.7 that had an impact on gradient correctness. Any idea @ToucheSir?

@ToucheSir
Member

If I had to guess, something about lowering changed between those two versions. The more concerning part is that our test suite didn't catch this. I've always had a sinking feeling that https://github.com/FluxML/Flux.jl/blob/master/test/layers/recurrent.jl did not provide sufficient coverage, and unfortunately this only confirms that...

@jeremiedb
Contributor

I'll open a PR by tomorrow to add the above gradients tests. I'm also disappointed not to have taken the time to manually validate those RNN gradients until now. Zygote is quite a footgun :\

@liuyxpp

liuyxpp commented Sep 27, 2023

> For reference, this is how you could use the new explicit gradient / Optimisers.jl mode:
>
> loss_fn1(m, X, y) = logitcrossentropy([m(x) for x ∈ X][end], y)
> accuracy1(m, X, y) = mean(onecold([m(x) for x ∈ X][end], 0:9) .== y)
>
> rule = Flux.Optimisers.Adam()
> opts = Flux.Optimisers.setup(rule, model);
>
> for epoch ∈ 1:5
>     for idx ∈ partition(shuffle(1:size(train_y, 2)), BATCH_SIZE)
>         X = [x[:, idx] for x ∈ train_x]
>         y = train_y[:, idx]
>         Flux.reset!(model)
>         gs = gradient(model) do m
>             loss_fn1(m, X, y)
>         end
>         Flux.Optimisers.update!(opts, model, gs[1]);
>     end
>     Flux.reset!(model)
>     test_acc = accuracy1(model, test_x, test_y)
>     @info "Epoch : $epoch | accuracy : $test_acc"
> end

In this implementation, Flux.reset!(model) is outside of the loss function. Does that mean the model's initial state is preserved across a batch? I don't think that is the expected behavior in most cases. The model should be reset for each sample, not for each batch.

@ToucheSir
Member

If you consider the initial state non-trainable, then I think it's mostly equivalent since other libraries are passing all zeros as the initial state. If you have a custom initial state however or want it to be trainable (which PyTorch at least does not appear to support directly), then it is not the same as you say. I'm unsure why the original design is the way it is (cc @mkschleg for possible theories), but reworking the initial state is one of those things we're investigating for our overhaul of the RNN API.

@jeremiedb
Contributor

jeremiedb commented Nov 12, 2023

Regarding the initialization of the initial state: although it may not be the common form encountered in PyTorch, this paper with an LSTM author as co-author points to the relevance of learning the initial state (see section 5.1 at page 135). This blog post also discusses it: https://r2rt.com/non-zero-initial-states-for-recurrent-neural-networks.html. I also have a vague recollection that MXNet used to have a learnable initial state as a feature, but I couldn't confirm it.

By applying reset!, the state of the model is set to a learnable initial state that will be applied to all observations in the input data X. Notice how the loss function iterates over all the "time-steps" of the input data: [m(x) for x ∈ X]. This means that for the first timestep, every observation in the batch has a common hidden state input. After that first timestep, the state of the model differs for each observation of the batch. It's not advised to put reset! within a loss function, as it's not a learnable operation but an assignment: the state of the model is assigned the initial-state parameter so that it's ready to receive a new batch of sequence data.
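
The behaviour described above can be mimicked with a small stand-alone toy. `ToyRecur` and `toy_reset!` below are hypothetical names written for this illustration and are not Flux's `Recur` or `reset!`; the sketch only assumes a shared (hidden, 1) initial state that broadcasts over the batch on the first timestep.

```julia
# Toy stateful recurrent layer; NOT Flux's Recur, just an illustration of the idea.
mutable struct ToyRecur
    Wi::Matrix{Float64}
    Wh::Matrix{Float64}
    b::Vector{Float64}
    state0::Matrix{Float64}   # shared initial state, size (hidden, 1)
    state::Matrix{Float64}    # current state, size (hidden, 1) or (hidden, batch)
end

function ToyRecur(nin::Int, nhidden::Int)
    ToyRecur(0.1 .* randn(nhidden, nin), 0.1 .* randn(nhidden, nhidden),
             zeros(nhidden), zeros(nhidden, 1), zeros(nhidden, 1))
end

# Assignment, not a learnable operation: point the state back at the initial state.
toy_reset!(m::ToyRecur) = (m.state = m.state0; m)

# One timestep: the (hidden, 1) state broadcasts across all columns of the batch.
function (m::ToyRecur)(x::AbstractMatrix)
    m.state = tanh.(m.Wi * x .+ m.Wh * m.state .+ m.b)
    return m.state
end

m = ToyRecur(3, 4)
X = [rand(3, 5) for _ in 1:28]      # 28 timesteps, 3 features, batch of 5

toy_reset!(m)
println("state after reset: ", size(m.state))   # one shared column
out = [m(x) for x in X]
println("state after batch: ", size(m.state))   # one column per observation
```

Resetting once per batch therefore still gives every observation the same starting state; what differs between observations is the state that evolves after the first timestep.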
