@@ -229,6 +229,8 @@ def __init__(self, *, decoder_residual_connection=False, **kwargs):
229229 ** kwargs ,
230230 )
231231
232+ def initialize (self ):
233+ super ().initialize ()
232234 self ._linear_layer = tf_keras .layers .Dense (
233235 self .num_tasks , activation = 'linear'
234236 )
@@ -476,18 +478,17 @@ def test_train_seq2seq_single_task(self, loss_type, loss_normalization):
476478 def test_train_seq2num_encoder_decoder_model (
477479 self , loss_type , loss_normalization
478480 ):
481+ model = TestEncoderDecoderGnnModel (
482+ graph_module_layer_normalization = False ,
483+ loss_normalization = loss_normalization ,
484+ loss_type = loss_type ,
485+ learning_rate = 0.01 ,
486+ )
479487 with mock .patch (
480488 'tf_keras.layers.LayerNormalization' ,
481489 side_effect = tf_keras .layers .LayerNormalization ,
482490 ) as tf_keras_layer_norm :
483- model = TestEncoderDecoderGnnModel (
484- graph_module_layer_normalization = False ,
485- loss_normalization = loss_normalization ,
486- loss_type = loss_type ,
487- learning_rate = 0.01 ,
488- )
489-
490- model .initialize ()
491+ model .initialize ()
491492 self .assertEqual (
492493 tf_keras_layer_norm .call_args_list ,
493494 [
@@ -504,6 +505,14 @@ def test_train_seq2num_encoder_decoder_model(
504505 def test_train_seq2seq_encoder_decoder_model (
505506 self , loss_type , loss_normalization
506507 ):
508+ model = TestEncoderDecoderGnnModel (
509+ graph_module_layer_normalization = True ,
510+ graph_module_residual_connections = False ,
511+ loss_normalization = loss_normalization ,
512+ loss_type = loss_type ,
513+ use_deltas = True ,
514+ learning_rate = 0.01 ,
515+ )
507516 with (
508517 mock .patch (
509518 'tf_keras.layers.LayerNormalization' ,
@@ -514,16 +523,7 @@ def test_train_seq2seq_encoder_decoder_model(
514523 side_effect = model_blocks .ResidualConnectionLayer ,
515524 ) as residual_connection_layer ,
516525 ):
517- model = TestEncoderDecoderGnnModel (
518- graph_module_layer_normalization = True ,
519- graph_module_residual_connections = False ,
520- loss_normalization = loss_normalization ,
521- loss_type = loss_type ,
522- use_deltas = True ,
523- learning_rate = 0.01 ,
524- )
525-
526- model .initialize ()
526+ model .initialize ()
527527 # NOTE(ondrasej): tf.math.add is called only when adding residual
528528 # connections. Since they are disabled in this test case, we should not see
529529 # any calls to this function.
@@ -563,6 +563,12 @@ def test_train_seq2seq_encoder_decoder_model(
563563 self .check_training_model (model )
564564
565565 def test_train_seq2seq_model_with_residual_connections (self ):
566+ model = TestEncoderDecoderGnnModel (
567+ graph_module_layer_normalization = True ,
568+ graph_module_residual_connections = True ,
569+ use_deltas = True ,
570+ learning_rate = 0.01 ,
571+ )
566572 with (
567573 mock .patch (
568574 'gematria.model.python.model_blocks.ResidualConnectionLayer' ,
@@ -573,14 +579,7 @@ def test_train_seq2seq_model_with_residual_connections(self):
573579 side_effect = tf_keras .layers .Dense ,
574580 ) as tf_keras_dense ,
575581 ):
576- model = TestEncoderDecoderGnnModel (
577- graph_module_layer_normalization = True ,
578- graph_module_residual_connections = True ,
579- use_deltas = True ,
580- learning_rate = 0.01 ,
581- )
582-
583- model .initialize ()
582+ model .initialize ()
584583 self .assertEqual (
585584 residual_connection_layer .call_args_list ,
586585 [
@@ -618,6 +617,14 @@ def test_train_seq2seq_model_with_residual_connections(self):
618617 def test_train_seq2seq_model_with_residual_connections_with_linear_transform (
619618 self ,
620619 ):
620+ model = TestEncoderDecoderGnnModel (
621+ graph_module_layer_normalization = False ,
622+ graph_module_residual_connections = False ,
623+ decoder_residual_connection = True ,
624+ use_deltas = True ,
625+ learning_rate = 0.01 ,
626+ )
627+
621628 with (
622629 mock .patch (
623630 'gematria.model.python.model_blocks.ResidualConnectionLayer' ,
@@ -628,14 +635,7 @@ def test_train_seq2seq_model_with_residual_connections_with_linear_transform(
628635 side_effect = tf_keras .layers .Dense ,
629636 ) as tf_keras_dense ,
630637 ):
631- model = TestEncoderDecoderGnnModel (
632- graph_module_layer_normalization = False ,
633- graph_module_residual_connections = False ,
634- decoder_residual_connection = True ,
635- use_deltas = True ,
636- learning_rate = 0.01 ,
637- )
638- model .initialize ()
638+ model .initialize ()
639639 self .assertEqual (
640640 residual_connection_layer .call_args_list ,
641641 [
@@ -648,22 +648,22 @@ def test_train_seq2seq_model_with_residual_connections_with_linear_transform(
648648 tf_keras_dense .call_args_list ,
649649 [
650650 mock .call (
651- activation = tf_keras .activations .linear ,
652- name = 'residual_connection_2_0_nodes_transformation' ,
653651 units = 5 ,
652+ activation = tf_keras .activations .linear ,
654653 use_bias = False ,
654+ name = 'residual_connection_2_0_nodes_transformation' ,
655655 ),
656656 mock .call (
657- activation = tf_keras .activations .linear ,
658- name = 'residual_connection_2_0_edges_transformation' ,
659657 units = 6 ,
658+ activation = tf_keras .activations .linear ,
660659 use_bias = False ,
660+ name = 'residual_connection_2_0_edges_transformation' ,
661661 ),
662662 mock .call (
663- activation = tf_keras .activations .linear ,
664- name = 'residual_connection_2_0_globals_transformation' ,
665663 units = 4 ,
664+ activation = tf_keras .activations .linear ,
666665 use_bias = False ,
666+ name = 'residual_connection_2_0_globals_transformation' ,
667667 ),
668668 mock .call (1 , activation = 'linear' ),
669669 ],
0 commit comments