Skip to content

Commit 994a9a8

Browse files
update
Created using spr 1.3.4
2 parents 9f87a66 + ca77619 commit 994a9a8

3 files changed

Lines changed: 61 additions & 65 deletions

File tree

gematria/granite/python/gnn_model_base.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,15 @@ def __init__(
226226
shape=global_feature_shape, dtype=global_feature_dtype or self.dtype
227227
)
228228

229-
self._graph_network = self._create_graph_network_modules()
230-
assert self._graph_network is not None
231-
232229
self._num_message_passing_iterations = num_message_passing_iterations
233230
self._graph_module_residual_connections = graph_module_residual_connections
234231
self._graph_module_layer_normalization = graph_module_layer_normalization
235232

233+
def initialize(self):
234+
super().initialize()
235+
self._graph_network = self._create_graph_network_modules()
236+
assert self._graph_network is not None
237+
236238
self._norm_layers = {}
237239
self._residual_layers = {}
238240
nodes_residual_shape = None

gematria/granite/python/gnn_model_base_test.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,8 @@ def __init__(self, *, decoder_residual_connection=False, **kwargs):
229229
**kwargs,
230230
)
231231

232+
def initialize(self):
233+
super().initialize()
232234
self._linear_layer = tf_keras.layers.Dense(
233235
self.num_tasks, activation='linear'
234236
)
@@ -476,18 +478,17 @@ def test_train_seq2seq_single_task(self, loss_type, loss_normalization):
476478
def test_train_seq2num_encoder_decoder_model(
477479
self, loss_type, loss_normalization
478480
):
481+
model = TestEncoderDecoderGnnModel(
482+
graph_module_layer_normalization=False,
483+
loss_normalization=loss_normalization,
484+
loss_type=loss_type,
485+
learning_rate=0.01,
486+
)
479487
with mock.patch(
480488
'tf_keras.layers.LayerNormalization',
481489
side_effect=tf_keras.layers.LayerNormalization,
482490
) as tf_keras_layer_norm:
483-
model = TestEncoderDecoderGnnModel(
484-
graph_module_layer_normalization=False,
485-
loss_normalization=loss_normalization,
486-
loss_type=loss_type,
487-
learning_rate=0.01,
488-
)
489-
490-
model.initialize()
491+
model.initialize()
491492
self.assertEqual(
492493
tf_keras_layer_norm.call_args_list,
493494
[
@@ -504,6 +505,14 @@ def test_train_seq2num_encoder_decoder_model(
504505
def test_train_seq2seq_encoder_decoder_model(
505506
self, loss_type, loss_normalization
506507
):
508+
model = TestEncoderDecoderGnnModel(
509+
graph_module_layer_normalization=True,
510+
graph_module_residual_connections=False,
511+
loss_normalization=loss_normalization,
512+
loss_type=loss_type,
513+
use_deltas=True,
514+
learning_rate=0.01,
515+
)
507516
with (
508517
mock.patch(
509518
'tf_keras.layers.LayerNormalization',
@@ -514,16 +523,7 @@ def test_train_seq2seq_encoder_decoder_model(
514523
side_effect=model_blocks.ResidualConnectionLayer,
515524
) as residual_connection_layer,
516525
):
517-
model = TestEncoderDecoderGnnModel(
518-
graph_module_layer_normalization=True,
519-
graph_module_residual_connections=False,
520-
loss_normalization=loss_normalization,
521-
loss_type=loss_type,
522-
use_deltas=True,
523-
learning_rate=0.01,
524-
)
525-
526-
model.initialize()
526+
model.initialize()
527527
# NOTE(ondrasej): tf.math.add is called only when adding residual
528528
# connections. Since they are disabled in this test case, we should not see
529529
# any calls to this function.
@@ -563,6 +563,12 @@ def test_train_seq2seq_encoder_decoder_model(
563563
self.check_training_model(model)
564564

565565
def test_train_seq2seq_model_with_residual_connections(self):
566+
model = TestEncoderDecoderGnnModel(
567+
graph_module_layer_normalization=True,
568+
graph_module_residual_connections=True,
569+
use_deltas=True,
570+
learning_rate=0.01,
571+
)
566572
with (
567573
mock.patch(
568574
'gematria.model.python.model_blocks.ResidualConnectionLayer',
@@ -573,14 +579,7 @@ def test_train_seq2seq_model_with_residual_connections(self):
573579
side_effect=tf_keras.layers.Dense,
574580
) as tf_keras_dense,
575581
):
576-
model = TestEncoderDecoderGnnModel(
577-
graph_module_layer_normalization=True,
578-
graph_module_residual_connections=True,
579-
use_deltas=True,
580-
learning_rate=0.01,
581-
)
582-
583-
model.initialize()
582+
model.initialize()
584583
self.assertEqual(
585584
residual_connection_layer.call_args_list,
586585
[
@@ -618,6 +617,14 @@ def test_train_seq2seq_model_with_residual_connections(self):
618617
def test_train_seq2seq_model_with_residual_connections_with_linear_transform(
619618
self,
620619
):
620+
model = TestEncoderDecoderGnnModel(
621+
graph_module_layer_normalization=False,
622+
graph_module_residual_connections=False,
623+
decoder_residual_connection=True,
624+
use_deltas=True,
625+
learning_rate=0.01,
626+
)
627+
621628
with (
622629
mock.patch(
623630
'gematria.model.python.model_blocks.ResidualConnectionLayer',
@@ -628,14 +635,7 @@ def test_train_seq2seq_model_with_residual_connections_with_linear_transform(
628635
side_effect=tf_keras.layers.Dense,
629636
) as tf_keras_dense,
630637
):
631-
model = TestEncoderDecoderGnnModel(
632-
graph_module_layer_normalization=False,
633-
graph_module_residual_connections=False,
634-
decoder_residual_connection=True,
635-
use_deltas=True,
636-
learning_rate=0.01,
637-
)
638-
model.initialize()
638+
model.initialize()
639639
self.assertEqual(
640640
residual_connection_layer.call_args_list,
641641
[
@@ -648,22 +648,22 @@ def test_train_seq2seq_model_with_residual_connections_with_linear_transform(
648648
tf_keras_dense.call_args_list,
649649
[
650650
mock.call(
651-
activation=tf_keras.activations.linear,
652-
name='residual_connection_2_0_nodes_transformation',
653651
units=5,
652+
activation=tf_keras.activations.linear,
654653
use_bias=False,
654+
name='residual_connection_2_0_nodes_transformation',
655655
),
656656
mock.call(
657-
activation=tf_keras.activations.linear,
658-
name='residual_connection_2_0_edges_transformation',
659657
units=6,
658+
activation=tf_keras.activations.linear,
660659
use_bias=False,
660+
name='residual_connection_2_0_edges_transformation',
661661
),
662662
mock.call(
663-
activation=tf_keras.activations.linear,
664-
name='residual_connection_2_0_globals_transformation',
665663
units=4,
664+
activation=tf_keras.activations.linear,
666665
use_bias=False,
666+
name='residual_connection_2_0_globals_transformation',
667667
),
668668
mock.call(1, activation='linear'),
669669
],

gematria/granite/python/graph_builder_model_base.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,21 @@ def __init__(
132132
**kwargs: Additional keyword arguments are passed to the constructor of
133133
the base class.
134134
"""
135-
token_model.TokenModel.__init__(
136-
self,
135+
# NOTE(ondrasej): We set the node/edge feature dtypes to int32. They are
136+
# indices to the token list/edge type; an int32 should be sufficient for all
137+
# our use cases and fixing the type will make it easier to move the array
138+
# construction to the C++ code if needed in the future. Similarly for the
139+
# graph index dtype.
140+
super().__init__(
141+
node_feature_shape=(),
142+
node_feature_dtype=tf.dtypes.int32,
143+
edge_feature_shape=(),
144+
edge_feature_dtype=tf.dtypes.int32,
145+
global_feature_shape=(len(tokens),),
146+
global_feature_dtype=tf.dtypes.int32,
147+
graph_index_dtype=tf.dtypes.int32,
137148
tokens=tokens,
138-
out_of_vocabulary_behavior=kwargs['out_of_vocabulary_behavior'],
139-
dtype=kwargs['dtype'],
149+
**kwargs,
140150
)
141151

142152
self._instruction_features = None
@@ -157,24 +167,6 @@ def __init__(
157167
)
158168
self._num_annotations = len(self._annotation_names_list)
159169

160-
# NOTE(ondrasej): We set the node/edge feature dtypes to int32. They are
161-
# indices to the token list/edge type; an int32 should be sufficient for all
162-
# our use cases and fixing the type will make it easier to move the array
163-
# construction to the C++ code if needed in the future. Similarly for the
164-
# graph index dtype.
165-
gnn_model_base.GnnModelBase.__init__(
166-
self,
167-
node_feature_shape=(),
168-
node_feature_dtype=tf.dtypes.int32,
169-
edge_feature_shape=(),
170-
edge_feature_dtype=tf.dtypes.int32,
171-
global_feature_shape=(len(tokens),),
172-
global_feature_dtype=tf.dtypes.int32,
173-
graph_index_dtype=tf.dtypes.int32,
174-
tokens=tokens,
175-
**kwargs,
176-
)
177-
178170
special_tokens = np.array(
179171
(
180172
self._batch_graph_builder.immediate_token,
@@ -220,9 +212,11 @@ def _make_batch_feed_dict(self) -> model_base.FeedDict:
220212
feed_dict['instruction_node_mask'] = np.array(
221213
self._batch_graph_builder.instruction_node_mask, dtype=bool
222214
)
215+
self._instruction_node_mask = feed_dict['instruction_node_mask']
223216
feed_dict['instruction_annotations'] = (
224217
self._batch_graph_builder.instruction_annotations
225218
)
219+
self._instruction_annotations = feed_dict['instruction_annotations']
226220
return feed_dict
227221

228222
# @Override

0 commit comments

Comments
 (0)