Error when batch_size different from 1 in NeuralNetworkRegressor #225

Open
@MathNog

When I pass batch_size as a parameter to NeuralNetworkRegressor(), the model can't be fitted because of a dimension mismatch.

I have written the following code:

using Flux, MLJFlux, Random

# Custom MLJFlux builder for a stack of LSTM layers
mutable struct LSTMBuilder <: MLJFlux.Builder
    input_size :: Int
    num_units :: Dict
    num_layers :: Int
end

function MLJFlux.build(lstm::LSTMBuilder, rng, n_in, n_out)
    input_size, num_units, num_layers = lstm.input_size, lstm.num_units, lstm.num_layers
    init = Flux.glorot_uniform(rng)
    Random.seed!(1234)
    # First LSTM layer maps the n_in input features to num_units[1] hidden units
    layers = [LSTM(n_in, num_units[1]), Dropout(0.1)]
    # Remaining LSTM layers, each followed by dropout
    for i in 1:num_layers-1
        layers = vcat(layers, [LSTM(num_units[i], num_units[i+1]), Dropout(0.1)])
    end
    # Dense output layer
    layers = vcat(layers, Dense(num_units[num_layers], n_out))
    Random.seed!(1234)
    model = Chain(layers)

    return model
end

model = NeuralNetworkRegressor(builder = LSTMBuilder(60, 4, 2),
                               rng = Random.GLOBAL_RNG,
                               epochs = 200,
                               loss = Flux.mse,
                               optimiser = ADAM(0.001),
                               batch_size = 16)

And the error message when training it is:

[ Info: Training machine(JackknifeRegressor(model = NeuralNetworkRegressor(builder = LSTMBuilder(input_size = 60, …), …), …), …).
Optimising neural net: 100%[=========================] Time: 0:00:03
┌ Error: Problem fitting the machine machine(JackknifeRegressor(model = NeuralNetworkRegressor(builder = LSTMBuilder(input_size = 60, …), …), …), …). 
└ @ MLJBase C:\Users\matheuscn.ELE\.julia\packages\MLJBase\5cxU0\src\machines.jl:682
[ Info: Running type checks... 
[ Info: Type checks okay. 
ERROR: DimensionMismatch: array could not be broadcast to match destination

I suspect this error occurs because the training loop does not call Flux.reset!() after each batch update.
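A minimal sketch of the suspected mechanism, outside MLJFlux. It assumes a Flux release with stateful recurrent layers (0.14 or earlier, where LSTM(in, out) keeps its hidden state between calls and Flux.reset! restores the initial state). The layer sizes and the two batch widths are illustrative and not taken from the failing run; the narrower batch stands in for a final batch when the number of observations is not a multiple of batch_size.

```julia
using Flux

# Assumption: Flux <= 0.14, where recurrent layers carry state between calls.
# All sizes below are illustrative.
m = Chain(LSTM(3, 5), Dense(5, 1))

x_full = rand(Float32, 3, 16)   # features x batch, a full batch of 16
x_last = rand(Float32, 3, 4)    # a narrower batch of 4

y_full = m(x_full)              # the stored hidden state now has 16 columns

# Without a reset, the stored 16-column state meets a 4-column input.
# Record whether that call throws.
failed = try
    m(x_last)
    false
catch
    true
end

Flux.reset!(m)                  # restore the initial hidden state
y_last = m(x_last)              # the narrower batch now goes through
```

If this is indeed the cause, resetting the model before each batch (or making every batch the same width) should avoid the mismatch, at the cost of not carrying hidden state across batches.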
