
Training resets (destroy) cloned (or saved) Neural Nets #949

Open
@aurium

Description

What is wrong?

I'm trying to save an LSTM net to localStorage so I can reuse it and continue training after reloading the web page.
However, any further training resets the restored ANN after every reload.
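For reference, the persistence flow I'm attempting looks like the sketch below. `fakeNet` and `localStorageLike` are stand-ins I made up so the snippet is self-contained; in the real page they are a `brain.recurrent.LSTM` instance and `window.localStorage`.

```javascript
// Minimal localStorage stand-in so the sketch runs outside a browser.
const storage = new Map();
const localStorageLike = {
  setItem: (k, v) => storage.set(k, String(v)),
  getItem: (k) => (storage.has(k) ? storage.get(k) : null),
};

// Stand-in net exposing the same toJSON/fromJSON shape brain.js uses.
const fakeNet = {
  weights: [0.1, -0.2, 0.3],
  toJSON() { return { weights: this.weights }; },
  fromJSON(json) { this.weights = json.weights.slice(); },
};

// Save after training:
localStorageLike.setItem('net', JSON.stringify(fakeNet.toJSON()));

// After a page reload, restore into a fresh net:
const restored = {
  weights: [],
  toJSON() { return { weights: this.weights }; },
  fromJSON(json) { this.weights = json.weights.slice(); },
};
restored.fromJSON(JSON.parse(localStorageLike.getItem('net')));

console.log(restored.weights); // → [ 0.1, -0.2, 0.3 ]
```

The JSON round-trip itself preserves the numbers exactly; the problem reported below appears only once the restored net is trained again.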

Where does it happen?

On the Firefox web browser.

How do we replicate the issue?

  1. Create an LSTM net.
  2. Train it.
  3. Create a second LSTM net.
  4. Load the first net's data into the second.
  5. See that both produce the same output.
  6. Train the first again: no problem, the error rate evolves normally.
  7. Train the second: the error rate explodes.
    (Test code is below.)

Expected behavior (i.e. solution)

The cloned (or saved-and-reloaded) net should continue evolving from the point where the original stopped.

Version information

Node.js: n/a (running in the browser)

Browser: Firefox 128

Brain.js: https://unpkg.com/[email protected]

How important is this (1-5)?

5

Other Comments

Test code:

<script src="https://unpkg.com/brain.js"></script>
<script>
const net = new brain.recurrent.LSTM({ hiddenLayers: [60, 60] })
net.maxPredictionLength = 100

const trainData = [
  'doe, a deer, a female deer',
  'ray, a drop of golden sun',
  'me, a name I call myself',
]

// First train
net.train(trainData, {
  iterations: 5000,
  log: true,
  logPeriod: 500,
  learningRate: 0.2,
})

// Clone the net:
const net2 = new brain.recurrent.LSTM({ hiddenLayers: [60, 60] })
net2.fromJSON(net.toJSON())

// Both output the same text:
console.log('ray 1:', net.run('ray'))
console.log('ray 2:', net2.run('ray'))

// More training, start from the last error rate:
net.train(trainData, {
  iterations: 30,
  log: true,
  logPeriod: 10,
  learningRate: 0.2,
})

// More training to the clone:
net2.train(trainData, {
  iterations: 30,
  log: true,
  logPeriod: 10,
  learningRate: 0.2,
})
// (???) That started with a BIG error rate!

// The first's quality degraded, but the second's output is nonsense:
console.log('ray 1:', net.run('ray'))
console.log('ray 2:', net2.run('ray'))
</script>

Example output:

iterations: 0, training error: Infinity
iterations: 500, training error: 0.01295498762785306
iterations: 1000, training error: 0.01216566726130864
iterations: 1500, training error: 0.012144481239444045
iterations: 2000, training error: 0.012128375937731972
iterations: 2500, training error: 186331.85451823566
iterations: 3000, training error: 0.013161829605093137
iterations: 3500, training error: 0.012466912897993345
iterations: 4000, training error: 0.013913231326531402
iterations: 4500, training error: 0.012443924643591718
ray 1: , a drop of golden sun
ray 2: , a drop of golden sun
iterations: 0, training error: 0.01232396260956039    <-- re-Training the original
iterations: 10, training error: 0.012362680081652325
iterations: 20, training error: 0.056502526340651955
iterations: 0, training error: Infinity               <-- Training the clone
iterations: 10, training error: 403136894.7168553
iterations: 20, training error: 469362.6213782614
ray 1: , a f g n of go go go goldnfseuff go go go goldnfseuff go go go golden
ray 2: eeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee
