Skip to content

Commit 09b8a27

Browse files
authored
Add a size argument to the input layer (#916)
1 parent c8c1c31 commit 09b8a27

File tree

3 files changed

+41
-3
lines changed

3 files changed

+41
-3
lines changed

lib/model/nns/layer/index.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ import Tensor from '../../../util/tensor.js'
181181
* { type: 'huber' } |
182182
* { type: 'identity' } |
183183
* { type: 'include', net: NeuralNetwork | object[], input_to?: string, train?: boolean } |
184-
* { type: 'input', name?: string } |
184+
* { type: 'input', name?: string, size?: number[] } |
185185
* { type: 'is_inf' } |
186186
* { type: 'is_nan' } |
187187
* { type: 'isigmoid', a?: number, alpha?: number } |

lib/model/nns/layer/input.js

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Layer from './base.js'
1+
import Layer, { NeuralnetworkLayerException } from './base.js'
22
import Matrix from '../../../util/matrix.js'
33
import Tensor from '../../../util/tensor.js'
44

@@ -9,10 +9,12 @@ export default class InputLayer extends Layer {
99
/**
1010
* @param {object} config object
1111
* @param {string} [config.name] Name of the layer
12+
* @param {number[]} [config.size] Size of the layer
1213
*/
13-
constructor({ name = null, ...rest }) {
14+
constructor({ name = null, size = null, ...rest }) {
1415
super(rest)
1516
this._name = name
17+
this._size = size
1618
}
1719

1820
bind({ input }) {
@@ -35,6 +37,13 @@ export default class InputLayer extends Layer {
3537
} else {
3638
this._o = new Matrix(1, 1, input)
3739
}
40+
41+
if (this._size) {
42+
const inSize = this._o.sizes
43+
if (inSize.length !== this._size.length || this._size.some((v, i) => v != null && v !== inSize[i])) {
44+
throw new NeuralnetworkLayerException(`Invalid input size`, [this])
45+
}
46+
}
3847
}
3948

4049
calc() {

tests/lib/model/nns/layer/input.test.js

+29
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,35 @@ describe('layer', () => {
1010
expect(layer).toBeDefined()
1111
})
1212

13+
describe('bind', () => {
14+
test.each([
15+
[null, null],
16+
[2, null],
17+
[null, 3],
18+
[2, 3],
19+
])('valid dim length [%p, %p]', (...size) => {
20+
expect.assertions(0)
21+
const layer = new InputLayer({ size })
22+
23+
const x = Matrix.randn(2, 3).toArray()
24+
layer.bind({ input: x })
25+
})
26+
27+
test('invalid dim length', () => {
28+
const layer = new InputLayer({ size: [null, null, null] })
29+
30+
const x = Matrix.randn(2, 3).toArray()
31+
expect(() => layer.bind({ input: x })).toThrow('Invalid input size')
32+
})
33+
34+
test('invalid dim value', () => {
35+
const layer = new InputLayer({ size: [3, 2] })
36+
37+
const x = Matrix.randn(2, 3).toArray()
38+
expect(() => layer.bind({ input: x })).toThrow('Invalid input size')
39+
})
40+
})
41+
1342
describe('calc', () => {
1443
test('scalar', () => {
1544
const layer = new InputLayer({})

0 commit comments

Comments
 (0)