@@ -11,6 +11,7 @@ import Matrix from '../../../../lib/util/matrix.js'
11
11
import ComputationalGraph from '../../../../lib/model/nns/graph.js'
12
12
13
13
import Layer from '../../../../lib/model/nns/layer/base.js'
14
+ import Tensor from '../../../../lib/util/tensor.js'
14
15
15
16
describe ( 'Computational Graph' , ( ) => {
16
17
test ( 'constructor' , ( ) => {
@@ -193,6 +194,40 @@ describe('Computational Graph', () => {
193
194
expect ( y [ i ] ) . toBeCloseTo ( Math . tanh ( x . value [ i ] ) )
194
195
}
195
196
} )
197
+
198
+ test ( 'complex layers' , async ( ) => {
199
+ const graph = new ComputationalGraph ( )
200
+ graph . add ( Layer . fromObject ( { type : 'input' , size : [ null , 6 , 6 , 3 ] } ) )
201
+ graph . add ( Layer . fromObject ( { type : 'conv' , kernel : 3 } ) )
202
+ graph . add ( Layer . fromObject ( { type : 'max_pool' , kernel : 2 } ) )
203
+ graph . add ( Layer . fromObject ( { type : 'relu' } ) )
204
+ graph . add ( Layer . fromObject ( { type : 'flatten' } ) )
205
+ graph . add ( Layer . fromObject ( { type : 'full' , out_size : 10 } ) )
206
+ graph . add ( Layer . fromObject ( { type : 'tanh' } ) , 'v' )
207
+ graph . add ( Layer . fromObject ( { type : 'pau' } ) )
208
+ graph . add ( Layer . fromObject ( { type : 'tanh' } ) , 'pau' )
209
+ graph . add ( Layer . fromObject ( { type : 'apl' } ) , 'apl' , 'v' )
210
+ graph . add ( Layer . fromObject ( { type : 'add' } ) , null , [ 'pau' , 'apl' ] )
211
+ graph . add ( Layer . fromObject ( { type : 'output' } ) )
212
+
213
+ const x = Tensor . randn ( [ 100 , 6 , 6 , 3 ] )
214
+ graph . bind ( { input : x } )
215
+ graph . calc ( )
216
+ const t = graph . outputNodes [ 0 ] . outputValue
217
+
218
+ const buf = await graph . toONNX ( )
219
+ session = await ort . InferenceSession . create ( buf )
220
+
221
+ const xten = new ort . Tensor ( 'float32' , x . value , x . sizes )
222
+ const out = await session . run ( { _input : xten } )
223
+ const yten = out . _add
224
+ expect ( yten . dims ) . toEqual ( [ 100 , 10 ] )
225
+ const y = await yten . getData ( true )
226
+
227
+ for ( let i = 0 ; i < y . length ; i ++ ) {
228
+ expect ( y [ i ] ) . toBeCloseTo ( t . value [ i ] )
229
+ }
230
+ } )
196
231
} )
197
232
198
233
describe ( 'add' , ( ) => {
0 commit comments