Skip to content

Commit 200c470

Browse files
authored
Accept object for add method of ComputationalGraph class (#902)
1 parent 908e892 commit 200c470

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

lib/model/neuralnetwork.js

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import Matrix from '../util/matrix.js'
22
import Tensor from '../util/tensor.js'
3-
import Layer from './nns/layer/base.js'
43
export { default as Layer } from './nns/layer/base.js'
54

65
import ComputationalGraph from './nns/graph.js'
@@ -61,13 +60,11 @@ export default class NeuralNetwork {
6160
}
6261
const graph = new ComputationalGraph()
6362
for (const cn of const_numbers) {
64-
const cl = Layer.fromObject({ type: 'const', value: [[cn]] })
65-
graph.add(cl, `__const_number_${cn}`, [])
63+
graph.add({ type: 'const', value: [[cn]] }, `__const_number_${cn}`, [])
6664
}
6765

6866
for (const l of layers) {
69-
const cl = Layer.fromObject(l)
70-
graph.add(cl, l.name, l.input)
67+
graph.add(l, l.name, l.input)
7168
}
7269

7370
return new NeuralNetwork(graph, optimizer)

lib/model/nns/graph.js

+9-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@ import { InputLayer, OutputLayer } from './layer/index.js'
55
import ONNXImporter from './onnx/onnx_importer.js'
66

77
/**
8-
* @typedef {import("./layer/index").PlainLayerObject & {input?: string | string[], name?: string}} LayerObject
8+
* @ignore
9+
* @typedef {import("./layer/index").PlainLayerObject} PlainLayerObject
10+
*/
11+
/**
12+
* @typedef {PlainLayerObject & {input?: string | string[], name?: string}} LayerObject
913
* @typedef {object} Node
1014
* @property {Layer} layer Layer
1115
* @property {string} name Name of the node
@@ -127,11 +131,14 @@ export default class ComputationalGraph {
127131

128132
/**
129133
* Add a layer.
130-
* @param {Layer} layer Added layer
134+
* @param {Layer | PlainLayerObject} layer Added layer
131135
* @param {string} [name] Node name
132136
* @param {string[] | string} [inputs] Input node names for the added layer
133137
*/
134138
add(layer, name, inputs = undefined) {
139+
if (!(layer instanceof Layer)) {
140+
layer = Layer.fromObject(layer)
141+
}
135142
let parentinfos = []
136143
if (!inputs) {
137144
if (this._nodes.length > 0) {

tests/lib/model/nns/graph.test.js

+10
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,16 @@ describe('Computational Graph', () => {
175175
expect(graph.nodes[1].parents[0].subscript).toBeNull()
176176
})
177177

178+
test('object', () => {
179+
const graph = new ComputationalGraph()
180+
graph.add({ type: 'input' })
181+
graph.add({ type: 'tanh' })
182+
183+
expect(graph.nodes[1].parents).toHaveLength(1)
184+
expect(graph.nodes[1].parents[0].index).toBe(0)
185+
expect(graph.nodes[1].parents[0].subscript).toBeNull()
186+
})
187+
178188
test('string input', () => {
179189
const graph = new ComputationalGraph()
180190
graph.add(Layer.fromObject({ type: 'input' }), 'in')

0 commit comments

Comments
 (0)