Skip to content

Commit 7ba037d

Browse files
authored
Fix constructor of APL layer (#920)
1 parent 95bf2af commit 7ba037d

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

lib/model/nns/layer/apl.js

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,18 @@ export default class AdaptivePiecewiseLinearLayer extends Layer {
1010
* @param {number | number[]} [config.a] Variables control the slopes of the linear segments
1111
* @param {number | number[]} [config.b] Variables determine the locations of the hinges
1212
*/
13-
constructor({ s = 2, a = 0.1, b = 0, ...rest }) {
13+
constructor({ s = 2, a = null, b = 0, ...rest }) {
1414
super(rest)
1515
this._s = s
1616
if (Array.isArray(a)) {
1717
this._a = a
1818
} else {
1919
this._a = []
2020
for (let k = 0; k < s; k++) {
21-
this._a[k] = Math.random()
21+
this._a[k] = a ?? Math.random()
2222
}
2323
}
24-
this._b = Array.isArray(b) ? b : Array(s).fill(0)
24+
this._b = Array.isArray(b) ? b : Array(s).fill(b)
2525
this._l2_decay = 0.001
2626
}
2727

tests/lib/model/nns/layer/apl.test.js

+11-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { jest } from '@jest/globals'
1+
import { expect, jest, test } from '@jest/globals'
22
jest.retryTimes(3)
33

44
import NeuralNetwork from '../../../../../lib/model/neuralnetwork.js'
@@ -8,9 +8,16 @@ import Tensor from '../../../../../lib/util/tensor.js'
88
import APLLayer from '../../../../../lib/model/nns/layer/apl.js'
99

1010
describe('layer', () => {
11-
test('construct', () => {
12-
const layer = new APLLayer({})
13-
expect(layer).toBeDefined()
11+
describe('construct', () => {
12+
test('default', () => {
13+
const layer = new APLLayer({})
14+
expect(layer).toBeDefined()
15+
})
16+
17+
test('number', () => {
18+
const layer = new APLLayer({ a: 2, b: 3 })
19+
expect(layer).toBeDefined()
20+
})
1421
})
1522

1623
describe('calc', () => {

0 commit comments

Comments
 (0)