Skip to content

Commit d74f8d5

Browse files
update
Created using spr 1.3.4
1 parent 25a0a30 commit d74f8d5

2 files changed

Lines changed: 45 additions & 43 deletions

File tree

gematria/granite/python/gnn_model_base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,15 @@ def __init__(
226226
shape=global_feature_shape, dtype=global_feature_dtype or self.dtype
227227
)
228228

229-
self._graph_network = self._create_graph_network_modules()
230-
assert self._graph_network is not None
231-
232229
self._num_message_passing_iterations = num_message_passing_iterations
233230
self._graph_module_residual_connections = graph_module_residual_connections
234231
self._graph_module_layer_normalization = graph_module_layer_normalization
235232

233+
def initialize(self):
234+
super().initialize()
235+
self._graph_network = self._create_graph_network_modules()
236+
assert self._graph_network is not None
237+
236238
self._norm_layers = {}
237239
self._residual_layers = {}
238240
nodes_residual_shape = None

gematria/granite/python/gnn_model_base_test.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,8 @@ def __init__(self, *, decoder_residual_connection=False, **kwargs):
229229
**kwargs,
230230
)
231231

232+
def initialize(self):
233+
super().initialize()
232234
self._linear_layer = tf_keras.layers.Dense(
233235
self.num_tasks, activation='linear'
234236
)
@@ -476,18 +478,17 @@ def test_train_seq2seq_single_task(self, loss_type, loss_normalization):
476478
def test_train_seq2num_encoder_decoder_model(
477479
self, loss_type, loss_normalization
478480
):
481+
model = TestEncoderDecoderGnnModel(
482+
graph_module_layer_normalization=False,
483+
loss_normalization=loss_normalization,
484+
loss_type=loss_type,
485+
learning_rate=0.01,
486+
)
479487
with mock.patch(
480488
'tf_keras.layers.LayerNormalization',
481489
side_effect=tf_keras.layers.LayerNormalization,
482490
) as tf_keras_layer_norm:
483-
model = TestEncoderDecoderGnnModel(
484-
graph_module_layer_normalization=False,
485-
loss_normalization=loss_normalization,
486-
loss_type=loss_type,
487-
learning_rate=0.01,
488-
)
489-
490-
model.initialize()
491+
model.initialize()
491492
self.assertEqual(
492493
tf_keras_layer_norm.call_args_list,
493494
[
@@ -504,6 +505,14 @@ def test_train_seq2num_encoder_decoder_model(
504505
def test_train_seq2seq_encoder_decoder_model(
505506
self, loss_type, loss_normalization
506507
):
508+
model = TestEncoderDecoderGnnModel(
509+
graph_module_layer_normalization=True,
510+
graph_module_residual_connections=False,
511+
loss_normalization=loss_normalization,
512+
loss_type=loss_type,
513+
use_deltas=True,
514+
learning_rate=0.01,
515+
)
507516
with (
508517
mock.patch(
509518
'tf_keras.layers.LayerNormalization',
@@ -514,16 +523,7 @@ def test_train_seq2seq_encoder_decoder_model(
514523
side_effect=model_blocks.ResidualConnectionLayer,
515524
) as residual_connection_layer,
516525
):
517-
model = TestEncoderDecoderGnnModel(
518-
graph_module_layer_normalization=True,
519-
graph_module_residual_connections=False,
520-
loss_normalization=loss_normalization,
521-
loss_type=loss_type,
522-
use_deltas=True,
523-
learning_rate=0.01,
524-
)
525-
526-
model.initialize()
526+
model.initialize()
527527
# NOTE(ondrasej): tf.math.add is called only when adding residual
528528
# connections. Since they are disabled in this test case, we should not see
529529
# any calls to this function.
@@ -563,6 +563,12 @@ def test_train_seq2seq_encoder_decoder_model(
563563
self.check_training_model(model)
564564

565565
def test_train_seq2seq_model_with_residual_connections(self):
566+
model = TestEncoderDecoderGnnModel(
567+
graph_module_layer_normalization=True,
568+
graph_module_residual_connections=True,
569+
use_deltas=True,
570+
learning_rate=0.01,
571+
)
566572
with (
567573
mock.patch(
568574
'gematria.model.python.model_blocks.ResidualConnectionLayer',
@@ -573,14 +579,7 @@ def test_train_seq2seq_model_with_residual_connections(self):
573579
side_effect=tf_keras.layers.Dense,
574580
) as tf_keras_dense,
575581
):
576-
model = TestEncoderDecoderGnnModel(
577-
graph_module_layer_normalization=True,
578-
graph_module_residual_connections=True,
579-
use_deltas=True,
580-
learning_rate=0.01,
581-
)
582-
583-
model.initialize()
582+
model.initialize()
584583
self.assertEqual(
585584
residual_connection_layer.call_args_list,
586585
[
@@ -618,6 +617,14 @@ def test_train_seq2seq_model_with_residual_connections(self):
618617
def test_train_seq2seq_model_with_residual_connections_with_linear_transform(
619618
self,
620619
):
620+
model = TestEncoderDecoderGnnModel(
621+
graph_module_layer_normalization=False,
622+
graph_module_residual_connections=False,
623+
decoder_residual_connection=True,
624+
use_deltas=True,
625+
learning_rate=0.01,
626+
)
627+
621628
with (
622629
mock.patch(
623630
'gematria.model.python.model_blocks.ResidualConnectionLayer',
@@ -628,14 +635,7 @@ def test_train_seq2seq_model_with_residual_connections_with_linear_transform(
628635
side_effect=tf_keras.layers.Dense,
629636
) as tf_keras_dense,
630637
):
631-
model = TestEncoderDecoderGnnModel(
632-
graph_module_layer_normalization=False,
633-
graph_module_residual_connections=False,
634-
decoder_residual_connection=True,
635-
use_deltas=True,
636-
learning_rate=0.01,
637-
)
638-
model.initialize()
638+
model.initialize()
639639
self.assertEqual(
640640
residual_connection_layer.call_args_list,
641641
[
@@ -648,22 +648,22 @@ def test_train_seq2seq_model_with_residual_connections_with_linear_transform(
648648
tf_keras_dense.call_args_list,
649649
[
650650
mock.call(
651-
activation=tf_keras.activations.linear,
652-
name='residual_connection_2_0_nodes_transformation',
653651
units=5,
652+
activation=tf_keras.activations.linear,
654653
use_bias=False,
654+
name='residual_connection_2_0_nodes_transformation',
655655
),
656656
mock.call(
657-
activation=tf_keras.activations.linear,
658-
name='residual_connection_2_0_edges_transformation',
659657
units=6,
658+
activation=tf_keras.activations.linear,
660659
use_bias=False,
660+
name='residual_connection_2_0_edges_transformation',
661661
),
662662
mock.call(
663-
activation=tf_keras.activations.linear,
664-
name='residual_connection_2_0_globals_transformation',
665663
units=4,
664+
activation=tf_keras.activations.linear,
666665
use_bias=False,
666+
name='residual_connection_2_0_globals_transformation',
667667
),
668668
mock.call(1, activation='linear'),
669669
],

0 commit comments

Comments
 (0)