Skip to content

Commit 9c44911

Browse files
authored
Enhance InputLayer to support default values and update tests (#948)
1 parent f149bac commit 9c44911

File tree

3 files changed

+36
-2
lines changed

3 files changed

+36
-2
lines changed

Diff for: lib/model/nns/layer/index.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ export { default as VarLayer } from './variance.js'
187187
* { type: 'huber' } |
188188
* { type: 'identity' } |
189189
* { type: 'include', net: NeuralNetwork | object[], input_to?: string, train?: boolean } |
190-
* { type: 'input', name?: string, size?: number[] } |
190+
* { type: 'input', name?: string, size?: number[], value?: number | number[] | number[][] | nunber[][][] | number[][][][] | Matrix | Tensor } |
191191
* { type: 'is_inf' } |
192192
* { type: 'is_nan' } |
193193
* { type: 'isigmoid', a?: number, alpha?: number } |

Diff for: lib/model/nns/layer/input.js

+7-1
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ export default class InputLayer extends Layer {
1010
* @param {object} config object
1111
* @param {string} [config.name] Name of the layer
1212
* @param {number[]} [config.size] Size of the layer
13+
* @param {number | number[] | number[][] | nunber[][][] | number[][][][] | Matrix | Tensor} [config.value] Default value
1314
*/
14-
constructor({ name = null, size = null, ...rest }) {
15+
constructor({ name = null, size = null, value, ...rest }) {
1516
super(rest)
1617
this._name = name
1718
this._size = size
19+
this._value = value
1820
}
1921

2022
bind({ input }) {
@@ -27,6 +29,9 @@ export default class InputLayer extends Layer {
2729
) {
2830
input = input[this._name]
2931
}
32+
if (input == null) {
33+
input = this._value
34+
}
3035
if (Array.isArray(input)) {
3136
this._o = Tensor.fromArray(input)
3237
if (this._o.dimension === 2) {
@@ -57,6 +62,7 @@ export default class InputLayer extends Layer {
5762
type: 'input',
5863
name: this._name,
5964
size: this._size?.concat(),
65+
value: this._value instanceof Matrix || this._value instanceof Tensor ? this._value.toArray() : this._value,
6066
}
6167
}
6268
}

Diff for: tests/lib/model/nns/layer/input.test.js

+28
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,26 @@ describe('layer', () => {
101101
}
102102
}
103103
})
104+
105+
test('default value', () => {
106+
const x = 1
107+
const layer = new InputLayer({ value: x })
108+
109+
layer.bind({})
110+
const y = layer.calc()
111+
expect(y.sizes).toEqual([1, 1])
112+
expect(y.at(0, 0)).toBe(1)
113+
})
114+
115+
test('default value and bind value', () => {
116+
const x = 1
117+
const layer = new InputLayer({ value: x })
118+
119+
layer.bind({ input: 2 })
120+
const y = layer.calc()
121+
expect(y.sizes).toEqual([1, 1])
122+
expect(y.at(0, 0)).toBe(2)
123+
})
104124
})
105125

106126
describe('grad', () => {
@@ -135,6 +155,14 @@ describe('layer', () => {
135155
const obj = layer.toObject()
136156
expect(obj).toEqual({ type: 'input', name: 'in', size: [null, 10] })
137157
})
158+
159+
test('matrix value', () => {
160+
const mat = Matrix.randn(1, 10)
161+
const layer = new InputLayer({ name: 'in', size: [null, 10], value: mat })
162+
163+
const obj = layer.toObject()
164+
expect(obj).toEqual({ type: 'input', name: 'in', size: [null, 10], value: mat.toArray() })
165+
})
138166
})
139167

140168
test('fromObject', () => {

0 commit comments

Comments
 (0)