@@ -469,30 +469,10 @@ const binaryLayers = {
469
469
return v
470
470
} ,
471
471
} ,
472
- equal : {
473
- calc : ( x1 , x2 ) => x1 === x2 ,
474
- gradFunc : ( ) => 0 ,
475
- } ,
476
- greater : {
477
- calc : ( x1 , x2 ) => x1 > x2 ,
478
- gradFunc : ( ) => 0 ,
479
- } ,
480
- greater_or_equal : {
481
- calc : ( x1 , x2 ) => x1 >= x2 ,
482
- gradFunc : ( ) => 0 ,
483
- } ,
484
472
left_bitshift : {
485
473
calc : ( x1 , x2 ) => x1 << x2 ,
486
474
gradFunc : ( ) => 0 ,
487
475
} ,
488
- less : {
489
- calc : ( x1 , x2 ) => x1 < x2 ,
490
- gradFunc : ( ) => 0 ,
491
- } ,
492
- less_or_equal : {
493
- calc : ( x1 , x2 ) => x1 <= x2 ,
494
- gradFunc : ( ) => 0 ,
495
- } ,
496
476
max : {
497
477
calc : Math . max ,
498
478
gradFunc : ( k , x ) => {
@@ -563,3 +543,44 @@ const binaryLayers = {
563
543
for ( const name of Object . keys ( binaryLayers ) ) {
564
544
buildBinaryLayer ( name , binaryLayers [ name ] . calc , binaryLayers [ name ] . gradFunc )
565
545
}
546
+
547
+ const buildCompareLayer = ( name , calcFunc ) => {
548
+ class TempLayer extends Layer {
549
+ calc ( ...x ) {
550
+ this . _i = x
551
+ this . _o = x [ 0 ] . copy ( )
552
+ this . _o . map ( ( ) => true )
553
+ for ( let i = 1 ; i < x . length ; i ++ ) {
554
+ const xi = x [ i - 1 ] . copy ( )
555
+ xi . broadcastOperate ( x [ i ] , calcFunc )
556
+ this . _o . broadcastOperate ( xi , ( a , b ) => a && b )
557
+ }
558
+ return this . _o
559
+ }
560
+
561
+ grad ( ) {
562
+ const bi = this . _i . map ( x => {
563
+ const bi = x . copy ( )
564
+ bi . fill ( 0 )
565
+ return bi
566
+ } )
567
+ return bi
568
+ }
569
+ }
570
+ Object . defineProperty ( TempLayer , 'name' , {
571
+ value : name . split ( '_' ) . reduce ( ( s , nm ) => s + nm [ 0 ] . toUpperCase ( ) + nm . substring ( 1 ) . toLowerCase ( ) , '' ) + 'Layer' ,
572
+ } )
573
+ TempLayer . registLayer ( name )
574
+ }
575
+
576
+ const compareLayers = {
577
+ equal : { calc : ( x1 , x2 ) => x1 === x2 } ,
578
+ greater : { calc : ( x1 , x2 ) => x1 > x2 } ,
579
+ greater_or_equal : { calc : ( x1 , x2 ) => x1 >= x2 } ,
580
+ less : { calc : ( x1 , x2 ) => x1 < x2 } ,
581
+ less_or_equal : { calc : ( x1 , x2 ) => x1 <= x2 } ,
582
+ }
583
+
584
+ for ( const name of Object . keys ( compareLayers ) ) {
585
+ buildCompareLayer ( name , compareLayers [ name ] . calc )
586
+ }
0 commit comments